PyTorch中inplace运算知识点总结(持续更新)
目录
*** 作含义
哪些 *** 作是in-place *** 作?
哪些 *** 作不是in-place?
PyTorch中的使用
深拷贝和浅拷贝
*** 作含义
可翻译为 原地 *** 作。对某个变量x进行运算时,其运算结果,如果保存在x所引用的相同内存地址上,则称该 *** 作为in-place。如果不是,则不是in-place。
在pytorch中经常加后缀“_”来代表原地 in-place operation,比如说.add_() 或者.scatter_()。python里面的+=,*=也是in-place operation。
官方文档:
Tensors — PyTorch Tutorials 1.11.0+cu102 documentation
What is `in-place operation`? - #2 by albanD - PyTorch Forums
哪些 *** 作是in-place *** 作?x += 1
x[i] = 1 #但是,此时x不是in-place
哪些 *** 作不是in-place?
x= x+1 # x+1产生一个新的内存地址,并将x的引用指向该新地址
PyTorch中的使用
- 可以使用函数 id() 检查。
- PyTorch layer中如果使用in-place,会导致无法反向传播。
import torch
import torch.nn as nn
class Net(mm.Module):
def __init__(self):
super(Net, self).__init()__
def forward(self, x):
for i in range(len(x)):
x[i] = 1
# 提示:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
# Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)
深拷贝和浅拷贝
判断一个变量,其指向的内存地址是否发生了变化。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)