In PyTorch, to tell whether an operation is differentiable, check whether gradients can still flow through it.
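A minimal sketch of that check (the tensor names and shapes here are illustrative, not from the original post): create inputs with requires_grad=True, apply the operation, and inspect grad_fn on the result.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

# The + operation records a node in the autograd graph,
# so the result still carries a grad_fn.
z = x + y
print(z.grad_fn)     # <AddBackward0 object at ...>

z.sum().backward()   # gradients propagate back through +
print(x.grad)        # tensor([1., 1., 1.]): d(sum)/dx = 1
```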
As the sketch shows, z obtained from the + operation still keeps the gradient flowing.
Operations like torch.argmax() and torch.eq(), on the other hand, do not work this way: they are not differentiable.
That is, once a non-differentiable operation sits on the backward path, backpropagation breaks down and the program raises an error on its own.
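A short illustration of that failure mode (a hedged sketch; the exact error text may vary by PyTorch version): torch.argmax() returns an integer index tensor that is detached from the autograd graph, so calling backward() through it fails.

```python
import torch

x = torch.randn(3, requires_grad=True)
idx = torch.argmax(x)

print(idx.grad_fn)        # None: the autograd graph is cut here
print(idx.requires_grad)  # False

try:
    idx.float().backward()
except RuntimeError as e:
    # e.g. "element 0 of tensors does not require grad
    # and does not have a grad_fn"
    print(e)
```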
An operation like soft argmax, by contrast, is differentiable:

```python
import torch
import torch.nn as nn

def soft_argmax(x):
    """
    Arguments: tensor of shape (batch_size, channel, L),
        e.g. an H*W feature map flattened along the last dim.
    Return: soft index of the maximum along the last dim,
        shape (batch_size, channel).
    """
    # alpha makes the largest element dominate, so its softmax
    # weight becomes very close to 1 and the weighted sum of
    # positions approaches the true argmax.
    alpha = 10000.0
    N, C, L = x.shape
    soft_max = nn.functional.softmax(x * alpha, dim=2)
    # position indices 0..L-1, broadcast against (N, C, L)
    indices_kernel = torch.arange(start=0, end=L).unsqueeze(0)
    # expected index = sum over positions of softmax weight * position
    indices = (soft_max * indices_kernel).sum(2)
    # For a voxel input flattened from (H, W, D), the index could be
    # decoded back into coordinates, as in the original code:
    # z = indices % D
    # y = indices.floor() % W
    # x = ((indices.floor()) / W).floor() % H
    # coords = torch.stack([x, y, z], dim=2)
    return indices

if __name__ == "__main__":
    x = torch.randn(1024, 16, 35 * 35, requires_grad=True)  # (batch_size, channel, H*W)
    coords = soft_argmax(x)  # shape (1024, 16): one soft index per channel
    # coords[0][0] is the (soft) index of the maximum in the first channel
    # coords[0][1] is the (soft) index of the maximum in the second channel
    print(coords)
```

Because it is built only from softmax, elementwise multiplication, and summation, every step of this operation is differentiable.
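To verify, one can backpropagate through the returned soft indices (a small check using the soft_argmax defined above, with an arbitrary small input shape):

```python
x = torch.randn(4, 2, 9, requires_grad=True)
coords = soft_argmax(x)
coords.sum().backward()  # runs without error: the graph is intact
print(x.grad.shape)      # torch.Size([4, 2, 9]): gradients reach the input
```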