PyTorch二进制分类-相同的网络结构,“更简单”的数据,但性能较差?

PyTorch二进制分类-相同的网络结构,“更简单”的数据,但性能较差?,第1张

PyTorch二进制分类-相同的网络结构,“更简单”的数据,但性能较差? TL; DR

您的输入数据未标准化。

  1. 采用
    x_data = (x_data - x_data.mean()) / x_data.std()
  2. 提高学习率
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

你会得到

仅1000次迭代即可收敛。

更多细节

这两个示例之间的主要区别在于

x
,第一个示例中的数据以(0,0)为中心,并且方差很小。
另一方面,第二示例中的数据以92为中心,并且具有相对较大的方差。

当您随机初始化权重时,不会考虑数据中的初始偏差,该权重是基于假设输入大致呈​​正态分布在
附近的假设完成的。
优化过程几乎不可能补偿该总偏差-因此模型陷入了次优的解决方案。

将输入标准化后,通过减去平均值并除以std,优化过程将再次变得稳定,并迅速收敛为一个好的解决方案。

有关输入归一化和权重初始化的更多详细信息,您可以阅读 He等人
深入研究整流器:在ImageNet分类中超越人类水平的性能 (ICCV
2015)中的第2.2节。

如果我无法规范化数据怎么办?

如果由于某种原因您无法提前计算均值和标准数据,则仍可以

nn.BatchNorm1d
在训练过程中使用它来估计和标准化数据。例如

class Model(nn.Module):    def __init__(self, input_size, H1, output_size):        super().__init__()        self.bn = nn.BatchNorm1d(input_size)  # adding batchnorm        self.linear = nn.Linear(input_size, H1)        self.linear2 = nn.Linear(H1, output_size)    def forward(self, x):        x = torch.sigmoid(self.linear(self.bn(x)))  # batchnorm the input x        x = torch.sigmoid(self.linear2(x))        return x

这种修改 不会 对输入数据 进行 任何更改,仅在1000个纪元后就会产生类似的收敛性:

轻微评论

为了数值稳定,最好使用

nn.BCEWithLogitsLoss
代替
nn.BCELoss
。为此,您需要
torch.sigmoid
forward()
输出中删除,这
sigmoid
将在损失内进行计算。



欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5617259.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存