masked_fill方法有两个参数,maske和value,mask是一个pytorch张量(Tensor),元素是布尔值,value是要填充的值,填充规则是mask中取值为True位置对应于待填充的相应位置用value填充。
import torch a=torch.tensor([[[5,5,5,5], [6,6,6,6], [7,7,7,7]], [[1,1,1,1],[2,2,2,2],[3,3,3,3]]]) print(a) print(a.size()) print("#############################################3") mask = torch.ByteTensor([[[1],[1],[0]],[[0],[1],[1]]]) print(mask.size()) b = a.masked_fill(mask, value=torch.tensor(-1e9)) print(b) print(b.size())
其实就是只要mask中的布尔值为True的话,就将待填充的对应的位置(此程序中的a)设置为value值
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)