Torch.spmm只支持 sparse 在前,dense 在后的矩阵乘法,两个sparse相乘或者dense在前的乘法不支持,当然两个dense矩阵相乘是支持的。
import torch if __name__ == '__main__': indices = torch.tensor([[0,1], [0,1]]) values = torch.tensor([2,3]) shape = torch.Size((2,2)) s = torch.sparse.FloatTensor(indices,values,shape) print(s) d = torch.tensor([[1,2], [3,4]]) e = torch.tensor([[1, 2], [3, 4]]) # print(d) # print(torch.spmm(s,d))
错误的情况会报错:
RuntimeError: sparse tensors do not have strides
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)