Several causes of numerical overflow in PyTorch

Error messages

The error typically looks like this:

xxx.py:xxx: RuntimeWarning: overflow encountered in reduce
xxx.py:xxx: RuntimeWarning: invalid value encountered in true_divide
...
xxx.cu:xxx: block: [xxx,0,0], thread: [0,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
xxx.cu:xxx: block: [xxx,0,0], thread: [1,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
xxx.cu:xxx: block: [xxx,0,0], thread: [2,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
xxx.cu:xxx: block: [xxx,0,0], thread: [3,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
xxx.cu:xxx: block: [xxx,0,0], thread: [4,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
xxx.cu:xxx: block: [xxx,0,0], thread: [5,0,0] Assertion `input_val >= zero && input_val <= xxx` failed.
...
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Cause

Whenever you see Assertion 'input_val >= zero && input_val <= xxx' failed., a value has gone out of range or overflowed. The overflow happens at the location where the RuntimeWarning (e.g. overflow encountered in reduce or invalid value encountered in true_divide) is reported.
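
Because CUDA reports kernel errors asynchronously (as the message above warns), the Python stack trace often points at an unrelated later call. A minimal sketch of forcing synchronous kernel launches so the device-side assert surfaces on the offending line; the variable must be set before the first CUDA call, or equivalently pass CUDA_LAUNCH_BLOCKING=1 python xxx.py on the command line:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"   # must happen before CUDA is initialized

import torch                               # CUDA kernels now launch synchronously,
                                           # so the assert points at the real culprit
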
A common case:

import torch
import torch.nn.functional as F
input = torch.Tensor([[1, 2], [2, 1]])            # contains values > 1
target = torch.Tensor([[0, 1], [1, 0]])
loss = F.binary_cross_entropy(input, target)      # expects input in [0, 1]

If you run this on the CPU, you will see:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-...> in <module>
      1 input=torch.Tensor([[1, 2], [2, 1]])
      2 target=torch.Tensor([[0, 1], [1, 0]])
----> 3 loss = F.binary_cross_entropy(input, target)

D:\Miniconda3\envs\dl\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2913         weight = weight.expand(new_size)
   2914
-> 2915     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   2916
   2917

RuntimeError: all elements of input should be between 0 and 1

If you run it on the GPU, you will see:

In [6]: input=torch.Tensor([[1, 2], [2, 1]]).cuda()
   ...: target=torch.Tensor([[0, 1], [1, 0]]).cuda()
   ...: loss = F.binary_cross_entropy(input, target)

In [7]: C:\w\b\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:115: block: [0,0,0], thread: [1,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\w\b\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:115: block: [0,0,0], thread: [2,0,0] Assertion `input_val >= zero && input_val <= one` failed.

Numerical overflow usually comes from one of the following causes.

1. input not restricted to [0, 1]

As in the example above: binary_cross_entropy requires every element of input to be between 0 and 1, but the input here contains values greater than 1. The usual fixes are sketched below.
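
A minimal sketch of the two usual fixes, assuming the out-of-range values are raw logits: either squash them with a sigmoid first, or switch to binary_cross_entropy_with_logits, which applies the sigmoid internally and is more numerically stable:

import torch
import torch.nn.functional as F

logits = torch.Tensor([[1, 2], [2, 1]])    # raw scores, not probabilities
target = torch.Tensor([[0, 1], [1, 0]])

loss1 = F.binary_cross_entropy(torch.sigmoid(logits), target)   # squash to (0, 1) first
loss2 = F.binary_cross_entropy_with_logits(logits, target)      # sigmoid applied internally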

2. Division by zero

Commonly occurs in normalization code:

# Feature normalization
mean = feature.mean((0, 2, 3), keepdims=True)
std = feature.std((0, 2, 3), keepdims=True)
feature = (feature - mean) / (std + 1e-8)   # + 1e-8 guards against std == 0

Just remember to add 1e-8 (a tiny epsilon) to the denominator. The sketch below shows the difference it makes.
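
A minimal sketch, assuming a feature map whose channels are all constant (so std is exactly 0), showing the nan that appears without the epsilon:

import torch

feature = torch.zeros(4, 3, 8, 8)                     # constant channels -> std == 0
mean = feature.mean((0, 2, 3), keepdim=True)
std = feature.std((0, 2, 3), keepdim=True)

bad = (feature - mean) / std                          # 0 / 0 -> nan
good = (feature - mean) / (std + 1e-8)                # stays finite

print(torch.isnan(bad).any(), torch.isnan(good).any())   # tensor(True) tensor(False)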

3. sqrt(0)

Also common in normalization code and in Normal distribution functions:

std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()   # sqrt(0) if a row has zero variance

Just remember to add a tiny epsilon before taking the square root:

std = torch.sqrt(torch.var(x, dim=1, unbiased=False, keepdim=True) + self.eps)
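
Note that sqrt(0) itself is just 0; the damage is done in the backward pass, because the derivative of sqrt is 1 / (2 * sqrt(x)), which is infinite at 0. A minimal sketch, using a plain eps variable in place of self.eps and rows of identical values (variance exactly 0):

import torch

eps = 1e-5                                            # stands in for self.eps
x = torch.ones(2, 4, requires_grad=True)              # constant rows -> variance == 0

bad = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
bad.sum().backward()
print(x.grad)                                         # nan: grad of sqrt blows up at 0

x.grad = None
good = torch.sqrt(torch.var(x, dim=1, unbiased=False, keepdim=True) + eps)
good.sum().backward()
print(x.grad)                                         # finite (all zeros here)
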
4. Overflow caused by applying the wrong operations to masked_array

If you use a masked_array (numpy.ma.core.MaskedArray),
check whether any regular NumPy operations are applied to it.
The session below shows np.concatenate applied to a masked_array, which silently drops the mask:

In [31]: data
Out[31]:
masked_array(
  data=[[0.0, 1.0, 2.0],
        [3.0, 4.0, --],
        [--, --, --]],
  mask=[[False, False, False],
        [False, False,  True],
        [ True,  True,  True]],
  fill_value=1e+20)

In [32]: np.concatenate([data, data])
Out[32]:
masked_array(
  data=[[0., 1., 2.],
        [3., 4., 0.],
        [0., 0., 0.],
        [0., 1., 2.],
        [3., 4., 0.],
        [0., 0., 0.]],
  mask=False,
  fill_value=1e+20)

At first glance this does not look like it could cause an overflow.
But some masked entries hold an extremely large value (the fill value), for example:

In [52]: arr
Out[52]:
masked_array(
  data=[[--, --, --, --, --, --, --, --, --, --],
        [--, --, --, --, --, --, --, --, --, --],
        [--, --, --, --, --, --, --, --, --, --],
        [--, --, --, --, --, --, --, --, --, --],
        [--, --, --, --, --, --, --, --, --, --],
        [--, --, --, --, --, --, --, --, --, --],
        [-1.1238000392913818, -1.1208000183105469, -1.1162999868392944,
         -1.1130000352859497, -1.1100000143051147, -1.1064000129699707,
         -1.1026999950408936, -1.0999000072479248, -1.097599983215332,
         -1.094099998474121],
        [-1.124500036239624, -1.1211999654769897, -1.1164000034332275,
         -1.1128000020980835, -1.1095999479293823, -1.1057000160217285,
         -1.1019999980926514, -1.0995999574661255, -1.0972000360488892,
         -1.093500018119812],
        [-1.1246000528335571, -1.1208000183105469, -1.1155999898910522,
         -1.111799955368042, -1.108299970626831, -1.1043000221252441,
         -1.1003999710083008, -1.0978000164031982, -1.0953999757766724,
         -1.0923000574111938],
        [-1.1224000453948975, -1.118499994277954, -1.1129000186920166,
         -1.1088999509811401, -1.1053999662399292, -1.1014000177383423,
         -1.097499966621399, -1.0946999788284302, -1.0922000408172607,
         -1.0896999835968018]],
  mask=[[ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [False, False, False, False, False, False, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False,
         False]],
  fill_value=9.96921e+36,
  dtype=float32)

In [53]: arr.data
Out[53]:
array([[ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [ 9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36,  9.96921e+36,  9.96921e+36,
         9.96921e+36,  9.96921e+36],
       [-1.12380e+00, -1.12080e+00, -1.11630e+00, -1.11300e+00,
        -1.11000e+00, -1.10640e+00, -1.10270e+00, -1.09990e+00,
        -1.09760e+00, -1.09410e+00],
       [-1.12450e+00, -1.12120e+00, -1.11640e+00, -1.11280e+00,
        -1.10960e+00, -1.10570e+00, -1.10200e+00, -1.09960e+00,
        -1.09720e+00, -1.09350e+00],
       [-1.12460e+00, -1.12080e+00, -1.11560e+00, -1.11180e+00,
        -1.10830e+00, -1.10430e+00, -1.10040e+00, -1.09780e+00,
        -1.09540e+00, -1.09230e+00],
       [-1.12240e+00, -1.11850e+00, -1.11290e+00, -1.10890e+00,
        -1.10540e+00, -1.10140e+00, -1.09750e+00, -1.09470e+00,
        -1.09220e+00, -1.08970e+00]], dtype=float32)

In [54]: arr.sum()
Out[54]: -44.281998

In [55]: arr.data.sum()
D:\Miniconda3\envs\dl\lib\site-packages\numpy\core\_methods.py:47: RuntimeWarning: overflow encountered in reduce
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
Out[55]: inf

Once a large number of 9.96921e+36 fill values leak past the mask, an operation such as sum() over them overflows.
Check every operation applied to the masked_array and use the np.ma versions instead of the plain np ones:

arr = np.ma.concatenate([arr, arr])   # the np.ma version preserves the mask
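
A minimal sketch contrasting the two calls, using an illustrative array whose masked entries hold fill-value-sized numbers:

import numpy as np

raw = np.full((5, 5), 9.96921e+36, dtype=np.float32)  # huge fill-like values
raw[0] = [1, 2, 3, 4, 5]                               # one row of real data
arr = np.ma.masked_array(raw, mask=(raw > 1e30))

print(arr.sum())                                       # 15.0, masked entries ignored

bad = np.concatenate([arr, arr])                       # plain numpy: mask is dropped
good = np.ma.concatenate([arr, arr])                   # np.ma: mask is preserved

print(bad.sum())                                       # inf: the huge values overflow float32
print(good.sum())                                      # 30.0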
