在学习经典 encode-decode 的 CNN 网络 U-Net 的 PyTorch 代码时,产生了如题的疑惑,在 Pytorch Forum上找到相关主题帖,记录之。
nn.Module,nn.ModuleList,nn.Sequential 都是 Class。
nn.Moudletorch.nn.Module 是 base class for all neural network modules,任何一个 module 都要继承它。
nn.ModuleListtorch.nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module’s. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:
class LinearNet(nn.Module): def __init__(self, input_size, num_layers, layers_size, output_size): super(LinearNet, self).__init__() self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)]) self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)]) self.linears.append(nn.Linear(layers_size, output_size)nn.Sequential
torch.nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module’s) of that net. Here’s an example:
class Flatten(nn.Module): def forward(self, x): N, C, H, W = x.size() # read in N, C, H, W return x.view(N, -1) simple_cnn = nn.Sequential( nn.Conv2d(3, 32, kernel_size=7, stride=2), nn.ReLU(inplace=True), Flatten(), nn.Linear(5408, 10), )区别
-
关于 torch.nn.Sequential,存储其中的 nn.Module 相互是一种 cascaded(处理有先后顺序的)的关系。而且,这种类型的 object 有 forward() 函数。
-
关于 torch.nn.ModuleList, 这种类型的 object 无 forward() 函数。存储其中的 nn.Module 无逻辑顺序关系。为什么不直接用 Python 自带的 list 类型存储 nn.Module 呢? 因为 Pytorch 能够 “识别” 在 nn.ModuleList 中的是 nn.Module ,Python 内置的 list 是不行的。如果对上文中举的例子 Class LinearNet 重新定义,用 Python 内置的 list 类型代替 nn.ModuleList ,在训练的时候 ,optimizer() 会报错: your model has no parameters。因为 Pytorch 不知道 在 Python 内置的 list 类型中是 parameters。但是使用 nn.ModuleList 不会报错。
When should I use nn.ModuleList and when should I use nn.Sequential?
torch.nn.Module 官方文档
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)