Pytorch中forward的调用

Pytorch中forward的调用,第1张

  • forward不是在模型创建的时候调用,而是在模型调用的时候调用
model = NeuralNet(tr_set.dataset.dim).to(device)  # Construct model and move to device

pred = model(x)  # model是一个可调用对象,即会自动调用__call__函数,这里会调用__call_impl函数,
                 # 其中的forward_call会调用NeuralNet中实现的forward,后续导致的其他调用在下面类中说明
class NeuralNet(nn.Module):
    ''' A simple fully-connected deep neural network '''

    def __init__(self, input_dim):
        super(NeuralNet, self).__init__()

        # Define your neural network here
        # TODO: How to modify this model to achieve better performance?
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # Mean squared error loss
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, x):  # 这里forward不是override的方法,因为编译器这一行没有o图标
        ''' Given input of size (batch_size x input_dim), compute output of the network '''
        return self.net(x).squeeze(1)  
        # 这里self.net是Sequential对象,Sequential类实现了Module类,Module类实现了__call__函![请添加图片描述](http://www.kaotop.com/file/tupian/20220510/b08b20342690454a88761183bdb4b07a.png)
数,故Module、Sequential实例都是可调用的 从module.py的forward_call(*input, **kwargs)调用过来
        # 这里self.net(x)会调用__call_impl,(self是Sequential类型,会自动调用__call_impl)从而会导致Sequential中的forward被调用,
        # 在Sequential的forward中对每个module(这里指Linear、RELU、Liner)循环调用,因此分别会调用Liner、RELU、Linear的__call_impl,从而它们各自的forward会被调用
        # 这里impl都是一个,但对不同的类通过impl中的forward_call会分别调用每个类自己的forward
    def cal_loss(self, pred, target):
        ''' Calculate loss '''
        # TODO: you may implement L1/L2 regularization here
        return self.criterion(pred, target)

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/883955.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-14
下一篇 2022-05-14

发表评论

登录后才能评论

评论列表(0条)

保存