PyTorch 中的 nn.ModuleList 和 nn.Sequential区别(入门笔记)

PyTorch 中的 nn.ModuleList 和 nn.Sequential区别(入门笔记),第1张

在学习到yolov3 spp源码时遇到的一些问题,其中之一是有关nn.ModuleList

部分引用:PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 - 知乎 (zhihu.com)

nn.ModuleList作用:搭建基础网络的时候,存储不同module即神经网络中的Linear、conv2d、Relu层等,与nn.Sequential有些类似。

ModuleList:顾名思义,专门用于存储module的list。

两者的区别(4点):

不同点1(是否自动前向传播):nn.Sequential内部实现了forward函数,因此可以不用写forward函数,而nn.ModuleList则没有实现内部forward函数。(forward函数即前向传播)

不同点2(能否重命名):nn.Sequential可以使用OrderedDict对每层进行重命名,如图model1和model2显示出来的结构前面由序号(0)(1)(2)(3)变为(conv1)(relu1)....

# Example of using Sequential
model1 = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(model1)
# Sequential(
#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (3): ReLU()
# )

# Example of using Sequential with OrderedDict
import collections
model2 = nn.Sequential(collections.OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model2)
# Sequential(
#   (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (relu1): ReLU()
#   (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (relu2): ReLU()
# )

不同点3(有无顺序):nn.Sequential里面的模块按照顺序进行排列的,上下部分有关联,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,(1)中out_feature与(2)中in_feature不同,因为没有关联

class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x

net = net3()
print(net)
# net3(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )
input = torch.randn(32, 5)
print(net(input).shape)
# torch.Size([32, 30])

不同点4:有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们

layers = [nn.Linear(10, 10) for i in range(5)]

完整的写法:

ModeleList实现方式

class net6(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
 
    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x
 
net = net6()
print(net)
# net6(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

Sequential 实现方式

net7 注意 * 这个 *** 作符,它可以把一个 list 拆开成一个个独立的元素。但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linear_list)
 
    def forward(self, x):
        self.x = self.linears(x)
        return x
 
net = net7()
print(net)
# net7(
#   (linears): Sequential(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

如下,更改self. linear设置可变成我们想要的顺序

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear = ([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
        self.linears = nn.Sequential(*self.linear)

    def forward(self, x):
        self.x = self.linears(x)
        return x

net = net7()
print(net)
# net7(
#   (linears): Sequential(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/904502.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-15
下一篇 2022-05-15

发表评论

登录后才能评论

评论列表(0条)

保存