怎样判断一个 *** 作是否是可导的(Pytorch)

怎样判断一个 *** 作是否是可导的(Pytorch),第1张

怎样判断一个 *** 作是否是可导的(Pytorch)

在Pytorch中,看一个 *** 作是否可导,即经过这个 *** 作梯度是否还能顺利传递

 可以看到,经过+ *** 作后得到的z,仍能保持梯度的传递

而像torch.argmax(),  torch.eq() 这些 *** 作就不行了,这些 *** 作就是不可导的

即,遇到不可导的,你反向传播都会出问题,程序自己就会报错

 

而如果像soft argmax

import torch
import torch.nn as nn

def soft_argmax(x):
	"""
	Arguments: voxel patch in shape (batch_size, channel, H, W, depth)
	Return: 3D coordinates in shape (batch_size, channel, 3)
	"""
	# alpha is here to make the largest element really big, so it
	# would become very close to 1 after softmax
	alpha = 10000.0 
	N,C,L = x.shape
	soft_max = nn.functional.softmax(x*alpha,dim=2)
	soft_max = soft_max.view(x.shape)
	indices_kernel = torch.arange(start=0, end=L).unsqueeze(0)
	# indices_kernel = indices_kernel.view((H,W,D))
	# indices_kernel = indices_kernel.view(H,W)
	conv = soft_max*indices_kernel
	indices = conv.sum(2)
	# z = indices%D
	# y = (indices).floor()%W
	# x = (((indices).floor())/W).floor()%H
	# coords = torch.stack([x,y,z],dim=2)
	# coords = torch.stack([x,y],dim=2)
    #coords[0][0]代表第一个channel的最大点的坐标值
    #coords[0][1]代表第2个channel的最大点的坐标值
	return indices
 
if __name__ == "__main__":
	x = torch.randn(1024,16,35*35,requires_grad=True) # (batch_size, channel, H, W, depth)
	coords = soft_argmax(x)
    #coords是[b,c,2]
	print(coords)

  *** 作就是可导的

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5720724.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-18
下一篇 2022-12-18

发表评论

登录后才能评论

评论列表(0条)

保存