PyTorch nn.ModuleList 和 nn.Sequential 区别

PyTorch nn.ModuleList 和 nn.Sequential 区别,第1张

PyTorch nn.ModuleList 和 nn.Sequential 区别

在学习经典 encode-decode 的 CNN 网络 U-Net 的 PyTorch 代码时,产生了如题的疑惑,在 Pytorch Forum上找到相关主题帖,记录之。

nn.Module,nn.ModuleList,nn.Sequential 都是 Class。

nn.Moudle

torch.nn.Module 是 base class for all neural network modules,任何一个 module 都要继承它。

nn.ModuleList

torch.nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module’s. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
  def __init__(self, input_size, num_layers, layers_size, output_size):
     super(LinearNet, self).__init__()

     self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
     self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
     self.linears.append(nn.Linear(layers_size, output_size)
nn.Sequential

torch.nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module’s) of that net. Here’s an example:

class Flatten(nn.Module):
  def forward(self, x):
    N, C, H, W = x.size() # read in N, C, H, W
    return x.view(N, -1)

simple_cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, stride=2),
            nn.ReLU(inplace=True),
            Flatten(), 
            nn.Linear(5408, 10),
          )
区别
  • 关于 torch.nn.Sequential,存储其中的 nn.Module 相互是一种 cascaded(处理有先后顺序的)的关系。而且,这种类型的 object 有 forward() 函数

  • 关于 torch.nn.ModuleList, 这种类型的 object 无 forward() 函数。存储其中的 nn.Module 无逻辑顺序关系。为什么不直接用 Python 自带的 list 类型存储 nn.Module 呢? 因为 Pytorch 能够 “识别” 在 nn.ModuleList 中的是 nn.Module ,Python 内置的 list 是不行的。如果对上文中举的例子 Class LinearNet 重新定义,用 Python 内置的 list 类型代替 nn.ModuleList ,在训练的时候 ,optimizer() 会报错: your model has no parameters。因为 Pytorch 不知道 在 Python 内置的 list 类型中是 parameters。但是使用 nn.ModuleList 不会报错。

参考资料

When should I use nn.ModuleList and when should I use nn.Sequential?

torch.nn.Module 官方文档

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5436875.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-11
下一篇 2022-12-11

发表评论

登录后才能评论

评论列表(0条)

保存