打印Show Attend and Tell的损失函数

打印Show Attend and Tell的损失函数,第1张

打印Show Attend and Tell的损失函数:
criterion = nn.CrossEntropyLoss().to(device)
print(criterion)

输出

CrossEntropyLoss()#交叉熵损失函数
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
print(normalize)

输出:

Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
for epoch in range(start_epoch, epochs):
    print(start_epoch)
    print(epochs)
    break

输出:

0
2
datasets.py文件中class CaptionDataset(Dataset)查看captions
data_folder = 'image data/dataset_Flickr8k/'
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'
print(os.path.join(data_folder, 'TRAIN' + '_CAPTIONS_' + data_name + '.json'))
with open(os.path.join(data_folder, 'TRAIN' + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
	captions = json.load(j)
	print(captions)

输出:

/***/image data/dataset_Flickr8k/TRAIN_CAPTIONS_flickr8k_5_cap_per_img_5_min_word_freq.json
[
	...
	[2631, 198, 196, 165, 44, 1, 36, 288, 197, 104, 198, 2630, 133, 428, 408, 702, 2630, 38, 1476, 38, 198, 274, 87, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 1, 2630, 44, 1, 256, 36, 288, 197, 104, 1, 428, 702, 2630, 38, 1, 89, 8, 1, 2, 56, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 9, 32, 8, 9, 733, 734, 4, 2630, 8, 192, 33, 34, 9, 91, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 1, 29, 32, 8, 1, 733, 734, 721, 168, 8, 9, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 1, 120, 32, 8, 1, 544, 734, 140, 155, 1, 26, 27, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 1, 120, 32, 368, 162, 27, 9, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
	[2631, 1, 29, 32, 64, 1, 2, 2630, 4, 368, 8, 9, 78, 42, 28, 8, 9, 366, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]
datasets.py文件中class CaptionDataset(Dataset)查看caplens
data_folder = 'image data/dataset_Flickr8k/'
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'
print(os.path.join(data_folder, 'TRAIN' + '_CAPLENS_' + data_name + '.json'))
with open(os.path.join(data_folder, 'TRAIN' + '_CAPLENS_' + data_name + '.json'), 'r') as j:
	caplens = json.load(j)
	print(caplens)

输出:

/***/image data/dataset_Flickr8k/TRAIN_CAPLENS_flickr8k_5_cap_per_img_5_min_word_freq.json
[...14, 10, 16, 11, 12, 8, 14, 8, 15, 9, 18, 15, 14, 24, 22, 17, 14, 15, 10, 19]
查看输入encoder的图片输出结果(很重要)

resnet101对一个batch的图片进行编码,输出的张量为torch.Size([64, 14, 14, 2048])

1个batch大小为64
高为14
宽为14
通道数为2048(最后一维是ResNet-101最后一个卷积块输出的特征通道数)
for i, (imgs, caps, caplens) in enumerate(train_loader):
     data_time.update(time.time() - start)
     # print(data_time)# 
     # Move to GPU, if available
     imgs = imgs.to(device)
     caps = caps.to(device) # [2631, 1, 2589, 4, 1330, 1, 10, 365, 211, 408, 856, 34, 1, 543, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
     caplens = caplens.to(device)
     imgs = encoder(imgs)
     print(imgs)
     break

输出:

torch.Size([64, 14, 14, 2048])
tensor(
[
	[
		[[9.0361e-02, 5.2851e-01, 2.1546e+00,  ..., 0.0000e+00, 0.0000e+00, 3.5757e-01],
          [2.1035e-01, 6.9238e-01, 2.0149e+00,  ..., 0.0000e+00, 9.5086e-01, 2.3448e-01],
          [3.3035e-01, 8.5626e-01, 1.8753e+00,  ..., 0.0000e+00, 1.9017e+00, 1.1140e-01],
          ...,
          [0.0000e+00, 2.1303e-01, 3.7578e-01,  ..., 1.7350e-02, 6.0041e-02, 3.7648e-01],
          [0.0000e+00, 1.9494e-01, 2.9860e-01,  ..., 8.6750e-03, 3.0021e-02, 4.0226e-01],
          [0.0000e+00, 1.7685e-01, 2.2142e-01,  ..., 0.0000e+00, 0.0000e+00, 4.2803e-01]
        ],

        [[4.5180e-02, 8.1279e-01, 1.0773e+00,  ..., 1.8461e-01, 8.3581e-01, 1.7878e-01],
          [1.0518e-01, 9.8500e-01, 1.0075e+00,  ..., 9.2305e-02, 1.7277e+00, 1.1724e-01],
          [1.6517e-01, 1.1572e+00, 9.3763e-01,  ..., 0.0000e+00, 2.6196e+00, 5.5702e-02],
          ...,
          [0.0000e+00, 5.1433e-01, 1.8789e-01,  ..., 8.6750e-03, 1.1592e-01, 1.8824e-01],
          [0.0000e+00, 5.3312e-01, 1.4930e-01,  ..., 4.3375e-03, 5.7961e-02, 2.0113e-01],
          [0.0000e+00, 5.5192e-01, 1.1071e-01,  ..., 0.0000e+00, 0.0000e+00, 2.1402e-01]
        ],

        [[0.0000e+00, 1.0971e+00, 0.0000e+00,  ..., 3.6922e-01,1.6716e+00, 0.0000e+00],
          [0.0000e+00, 1.2776e+00, 0.0000e+00,  ..., 1.8461e-01,2.5045e+00, 0.0000e+00],
          [0.0000e+00, 1.4582e+00, 0.0000e+00,  ..., 0.0000e+00,3.3375e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 8.1563e-01, 0.0000e+00,  ..., 0.0000e+00,1.7180e-01, 0.0000e+00],
          [0.0000e+00, 8.7131e-01, 0.0000e+00,  ..., 0.0000e+00,8.5902e-02, 0.0000e+00],
          [0.0000e+00, 9.2699e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00]
        ],

         ...,

        [[0.0000e+00, 1.1359e+00, 0.0000e+00,  ..., 2.0694e-01,1.1926e+00, 9.9706e-01],
          [0.0000e+00, 1.1718e+00, 0.0000e+00,  ..., 1.8501e-01,1.7524e+00, 9.7201e-01],
          [0.0000e+00, 1.2076e+00, 0.0000e+00,  ..., 1.6309e-01,2.3122e+00, 9.4696e-01],
          ...,
          [0.0000e+00, 4.0909e-01, 0.0000e+00,  ..., 2.9820e-02,8.7141e-01, 0.0000e+00],
          [0.0000e+00, 5.2929e-01, 1.7152e-01,  ..., 1.4910e-02,8.3648e-01, 0.0000e+00],
          [0.0000e+00, 6.4950e-01, 3.4304e-01,  ..., 0.0000e+00,8.0156e-01, 0.0000e+00]
        ],

        [[0.0000e+00, 1.0725e+00, 0.0000e+00,  ..., 1.0347e-01,9.2755e-01, 7.7783e-01],
          [0.0000e+00, 8.8376e-01, 0.0000e+00,  ..., 9.2507e-02,1.2014e+00, 7.1052e-01],
          [0.0000e+00, 6.9500e-01, 0.0000e+00,  ..., 8.1545e-02,1.4753e+00, 6.4321e-01],
          ...,
          [0.0000e+00, 6.4680e-01, 2.7190e-01,  ..., 1.4910e-02,6.5424e-01, 0.0000e+00],
          [0.0000e+00, 6.4789e-01, 5.2873e-01,  ..., 7.4550e-03,7.9419e-01, 0.0000e+00],
          [0.0000e+00, 6.4898e-01, 7.8556e-01,  ..., 0.0000e+00,9.3414e-01, 0.0000e+00]
        ],

        [[0.0000e+00, 1.0091e+00, 0.0000e+00,  ..., 0.0000e+00,6.6249e-01, 5.5859e-01],
          [0.0000e+00, 5.9575e-01, 0.0000e+00,  ..., 0.0000e+00,6.5037e-01, 4.4903e-01],
          [0.0000e+00, 1.8237e-01, 0.0000e+00,  ..., 0.0000e+00,6.3826e-01, 3.3946e-01],
          ...,
          [0.0000e+00, 8.8452e-01, 5.4380e-01,  ..., 0.0000e+00,4.3707e-01, 0.0000e+00],
          [0.0000e+00, 7.6649e-01, 8.8594e-01,  ..., 0.0000e+00,7.5190e-01, 0.0000e+00],
          [0.0000e+00, 6.4845e-01, 1.2281e+00,  ..., 0.0000e+00,1.0667e+00, 0.0000e+00]
        ]
	],


    [
   		[[3.1328e-01, 4.3026e-01, 1.2069e+00,  ..., 3.4335e-01,0.0000e+00, 1.6301e-01],
          [1.5664e-01, 7.5903e-01, 1.1955e+00,  ..., 3.6088e-01,0.0000e+00, 2.1009e-01],
          [0.0000e+00, 1.0878e+00, 1.1842e+00,  ..., 3.7841e-01,0.0000e+00, 2.5717e-01],
          ...,
          [0.0000e+00, 9.5698e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.5350e+00],
          [0.0000e+00, 1.1680e+00, 6.2503e-01,  ..., 0.0000e+00,0.0000e+00, 1.5504e+00],
          [0.0000e+00, 1.3790e+00, 1.2501e+00,  ..., 0.0000e+00,0.0000e+00, 1.5659e+00]
        ],

         [[1.5664e-01, 5.1228e-01, 2.1516e+00,  ..., 1.7679e-01,9.3442e-02, 1.5183e-01],
          [7.8320e-02, 6.2838e-01, 1.9712e+00,  ..., 1.9270e-01,4.6721e-02, 1.9961e-01],
          [0.0000e+00, 7.4448e-01, 1.7908e+00,  ..., 2.0861e-01,0.0000e+00, 2.4739e-01],
          ...,
          [3.5671e-02, 1.1973e+00, 4.5830e-01,  ..., 0.0000e+00,0.0000e+00, 1.3950e+00],
          [1.0273e-01, 1.1930e+00, 7.2625e-01,  ..., 0.0000e+00,0.0000e+00, 1.3092e+00],
          [1.6979e-01, 1.1886e+00, 9.9420e-01,  ..., 0.0000e+00,0.0000e+00, 1.2233e+00]],

         [[0.0000e+00, 5.9430e-01, 3.0962e+00,  ..., 1.0227e-02,1.8688e-01, 1.4064e-01],
          [0.0000e+00, 4.9773e-01, 2.7468e+00,  ..., 2.4522e-02,9.3442e-02, 1.8913e-01],
          [0.0000e+00, 4.0116e-01, 2.3975e+00,  ..., 3.8816e-02,0.0000e+00, 2.3761e-01],
          ...,
          [7.1341e-02, 1.4377e+00, 9.1659e-01,  ..., 0.0000e+00,0.0000e+00, 1.2551e+00],
          [2.0546e-01, 1.2180e+00, 8.2746e-01,  ..., 0.0000e+00,0.0000e+00, 1.0679e+00],
          [3.3958e-01, 9.9825e-01, 7.3833e-01,  ..., 0.0000e+00,0.0000e+00, 8.8069e-01]],

         ...,

         [[2.1133e-01, 6.3768e-01, 2.9080e+00,  ..., 0.0000e+00,4.4643e-01, 7.4818e-01],
          [2.8350e-01, 6.8930e-01, 2.8615e+00,  ..., 0.0000e+00,3.0719e-01, 7.3427e-01],
          [3.5566e-01, 7.4091e-01, 2.8149e+00,  ..., 0.0000e+00,1.6795e-01, 7.2036e-01],
          ...,
          [1.4313e+00, 1.8079e+00, 1.1452e+00,  ..., 0.0000e+00,2.4840e-03, 8.6259e-01],
          [1.0502e+00, 1.9889e+00, 1.4327e+00,  ..., 0.0000e+00,6.1725e-02, 1.2689e+00],
          [6.6908e-01, 2.1700e+00, 1.7202e+00,  ..., 0.0000e+00,1.2097e-01, 1.6752e+00]],

         [[1.5723e-01, 1.1452e+00, 3.1622e+00,  ..., 1.5288e-01,2.2322e-01, 5.7149e-01],
          [1.6753e-01, 1.0057e+00, 3.2275e+00,  ..., 9.2347e-02,2.3731e-01, 6.1787e-01],
          [1.7783e-01, 8.6625e-01, 3.2928e+00,  ..., 3.1819e-02,2.5141e-01, 6.6425e-01],
          ...,
          [1.0436e+00, 1.6407e+00, 1.6283e+00,  ..., 0.0000e+00,1.0191e-01, 8.7368e-01],
          [7.4215e-01, 1.8037e+00, 1.7287e+00,  ..., 0.0000e+00,1.2338e-01, 1.1357e+00],
          [4.4071e-01, 1.9668e+00, 1.8292e+00,  ..., 0.0000e+00,1.4485e-01, 1.3978e+00]],

         [[1.0312e-01, 1.6527e+00, 3.4163e+00,  ..., 3.0575e-01,0.0000e+00, 3.9481e-01],
          [5.1561e-02, 1.3221e+00, 3.5935e+00,  ..., 1.8469e-01,1.6743e-01, 5.0147e-01],
          [0.0000e+00, 9.9158e-01, 3.7707e+00,  ..., 6.3637e-02,3.3487e-01, 6.0814e-01],
          ...,
          [6.5591e-01, 1.4736e+00, 2.1113e+00,  ..., 0.0000e+00,2.0134e-01, 8.8477e-01],
          [4.3412e-01, 1.6185e+00, 2.0248e+00,  ..., 0.0000e+00,1.8504e-01, 1.0025e+00],
          [2.1233e-01, 1.7635e+00, 1.9382e+00,  ..., 0.0000e+00,1.6873e-01, 1.1203e+00]
        ]
	],


    [
    	[[0.0000e+00, 0.0000e+00, 9.5545e-01,  ..., 8.0960e-01,0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 8.0015e-01,  ..., 5.5233e-01,0.0000e+00, 7.4726e-03],
          [0.0000e+00, 0.0000e+00, 6.4485e-01,  ..., 2.9505e-01,0.0000e+00, 1.4945e-02],
          ...,
          [0.0000e+00, 3.3634e-01, 8.3794e-01,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00],
          [0.0000e+00, 3.4318e-01, 8.7203e-01,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00],
          [0.0000e+00, 3.5001e-01, 9.0612e-01,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 5.6321e-01,  ..., 6.7477e-01,0.0000e+00, 9.3554e-04],
          [0.0000e+00, 0.0000e+00, 4.7451e-01,  ..., 5.0247e-01,0.0000e+00, 3.0180e-02],
          [0.0000e+00, 0.0000e+00, 3.8581e-01,  ..., 3.3016e-01,0.0000e+00, 5.9425e-02],
          ...,
          [0.0000e+00, 6.7129e-01, 1.1972e+00,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00],
          [0.0000e+00, 5.5759e-01, 1.6487e+00,  ..., 0.0000e+00,1.0493e-02, 0.0000e+00],
          [0.0000e+00, 4.4389e-01, 2.1002e+00,  ..., 0.0000e+00,2.0986e-02, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 1.7097e-01,  ..., 5.3994e-01,0.0000e+00, 1.8711e-03],
          [0.0000e+00, 0.0000e+00, 1.4887e-01,  ..., 4.5260e-01,0.0000e+00, 5.2887e-02],
          [0.0000e+00, 0.0000e+00, 1.2677e-01,  ..., 3.6527e-01,0.0000e+00, 1.0390e-01],
          ...,
          [0.0000e+00, 1.0062e+00, 1.5565e+00,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00],
          [0.0000e+00, 7.7201e-01, 2.4254e+00,  ..., 0.0000e+00,2.0986e-02, 0.0000e+00],
          [0.0000e+00, 5.3778e-01, 3.2942e+00,  ..., 0.0000e+00,4.1972e-02, 0.0000e+00]],

         ...,

         [[1.7784e-03, 6.2061e-02, 4.6169e-01,  ..., 3.2132e-01,0.0000e+00, 1.1172e-02],
          [1.2874e-01, 9.7284e-02, 2.3084e-01,  ..., 3.0055e-01,0.0000e+00, 6.3651e-02],
          [2.5569e-01, 1.3251e-01, 0.0000e+00,  ..., 2.7979e-01,0.0000e+00, 1.1613e-01],
          ...,
          [5.3011e-01, 1.0567e-01, 0.0000e+00,  ..., 4.1936e-01,5.7577e-02, 1.2729e-01],
          [4.6102e-01, 1.8947e-01, 1.7549e-01,  ..., 3.3047e-01,9.7484e-02, 2.2460e-01],
          [3.9193e-01, 2.7327e-01, 3.5098e-01,  ..., 2.4158e-01,1.3739e-01, 3.2192e-01]],

         [[1.8239e-01, 9.0274e-02, 1.1396e+00,  ..., 3.3605e-01,0.0000e+00, 5.5860e-03],
          [2.7868e-01, 1.1954e-01, 8.2689e-01,  ..., 3.8124e-01,0.0000e+00, 1.0846e-01],
          [3.7498e-01, 1.4880e-01, 5.1420e-01,  ..., 4.2642e-01,0.0000e+00, 2.1132e-01],
          ...,
          [4.7168e-01, 2.4409e-01, 4.4172e-01,  ..., 4.4215e-01,9.2666e-02, 2.3359e-01],
          [4.8129e-01, 3.6893e-01, 6.0175e-01,  ..., 3.4435e-01,1.7973e-01, 2.9912e-01],
          [4.9090e-01, 4.9378e-01, 7.6177e-01,  ..., 2.4655e-01,2.6680e-01, 3.6464e-01]],

         [[3.6300e-01, 1.1849e-01, 1.8174e+00,  ..., 3.5079e-01,0.0000e+00, 0.0000e+00],
          [4.2863e-01, 1.4179e-01, 1.4229e+00,  ..., 4.6192e-01,0.0000e+00, 1.5326e-01],
          [4.9426e-01, 1.6509e-01, 1.0284e+00,  ..., 5.7305e-01,0.0000e+00, 3.0652e-01],
          ...,
          [4.1326e-01, 3.8251e-01, 8.8344e-01,  ..., 4.6493e-01,1.2776e-01, 3.3989e-01],
          [5.0157e-01, 5.4840e-01, 1.0280e+00,  ..., 3.5823e-01,2.6198e-01, 3.7363e-01],
          [5.8988e-01, 7.1429e-01, 1.1726e+00,  ..., 2.5153e-01,3.9621e-01, 4.0737e-01]
         ]
 	],


        ...,


        [[[2.4759e-01, 0.0000e+00, 0.0000e+00,  ..., 4.0495e-02,1.3125e-01, 0.0000e+00],
          [2.2211e-01, 0.0000e+00, 0.0000e+00,  ..., 2.0248e-02,2.8641e-01, 2.7549e-02],
          [1.9663e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,4.4156e-01, 5.5099e-02],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,4.2970e-02, 1.7396e-01],
          [0.0000e+00, 9.1169e-02, 0.0000e+00,  ..., 0.0000e+00,2.1485e-02, 1.5837e-01],
          [0.0000e+00, 1.8234e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.4279e-01]
         ],

         [[1.4166e-01, 0.0000e+00, 0.0000e+00,  ..., 7.0862e-02,6.5624e-02, 0.0000e+00],
          [1.3300e-01, 0.0000e+00, 0.0000e+00,  ..., 1.2258e-01,2.0074e-01, 1.3775e-02],
          [1.2434e-01, 0.0000e+00, 0.0000e+00,  ..., 1.7431e-01,3.3585e-01, 2.7549e-02],
          ...,
          [0.0000e+00, 9.8729e-02, 0.0000e+00,  ..., 0.0000e+00,3.9813e-01, 2.7212e-01],
          [0.0000e+00, 1.9067e-01, 0.0000e+00,  ..., 0.0000e+00,1.9907e-01, 1.7175e-01],
          [0.0000e+00, 2.8260e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 7.1393e-02]
         ],

         [[3.5734e-02, 0.0000e+00, 0.0000e+00,  ..., 1.0123e-01,0.0000e+00, 0.0000e+00],
          [4.3895e-02, 0.0000e+00, 0.0000e+00,  ..., 2.2492e-01,1.1507e-01, 0.0000e+00],
          [5.2056e-02, 0.0000e+00, 0.0000e+00,  ..., 3.4861e-01,2.3014e-01, 0.0000e+00],
          ...,
          [0.0000e+00, 1.9746e-01, 0.0000e+00,  ..., 0.0000e+00,7.5330e-01, 3.7028e-01],
          [0.0000e+00, 2.9016e-01, 0.0000e+00,  ..., 0.0000e+00,3.7665e-01, 1.8514e-01],
          [0.0000e+00, 3.8287e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00]
         ],

         ...,

         [[4.6125e-02, 3.5640e-01, 2.0939e+00,  ..., 0.0000e+00,0.0000e+00, 5.0760e-02],
          [2.3063e-02, 3.1719e-01, 1.5483e+00,  ..., 1.7244e-02,0.0000e+00, 2.5380e-02],
          [0.0000e+00, 2.7799e-01, 1.0028e+00,  ..., 3.4487e-02,0.0000e+00, 0.0000e+00],
          ...,
          [7.0788e-01, 2.2093e-01, 0.0000e+00,  ..., 3.0831e-01,0.0000e+00, 1.2356e+00],
          [1.1000e+00, 1.8086e-01, 0.0000e+00,  ..., 3.0003e-01,0.0000e+00, 1.3106e+00],
          [1.4922e+00, 1.4078e-01, 0.0000e+00,  ..., 2.9175e-01,0.0000e+00, 1.3855e+00]
         ],

         [[2.3063e-02, 3.1117e-01, 1.5658e+00,  ..., 0.0000e+00,1.3866e-01, 2.5380e-02],
          [1.1531e-02, 2.9614e-01, 1.3523e+00,  ..., 8.6218e-03,6.9332e-02, 1.2690e-02],
          [0.0000e+00, 2.8111e-01, 1.1389e+00,  ..., 1.7244e-02,0.0000e+00, 0.0000e+00],
          ...,
          [6.4155e-01, 1.1047e-01, 0.0000e+00,  ..., 2.0347e-01,0.0000e+00, 9.2113e-01],
          [8.8613e-01, 9.0429e-02, 0.0000e+00,  ..., 2.4763e-01,0.0000e+00, 1.0280e+00],
          [1.1307e+00, 7.0391e-02, 0.0000e+00,  ..., 2.9179e-01,0.0000e+00, 1.1349e+00]
         ],

         [[0.0000e+00, 2.6594e-01, 1.0377e+00,  ..., 0.0000e+00,2.7733e-01, 0.0000e+00],
          [0.0000e+00, 2.7508e-01, 1.1563e+00,  ..., 0.0000e+00,1.3866e-01, 0.0000e+00],
          [0.0000e+00, 2.8422e-01, 1.2750e+00,  ..., 0.0000e+00,0.0000e+00, 0.0000e+00],
          ...,
          [5.7521e-01, 0.0000e+00, 0.0000e+00,  ..., 9.8626e-02,0.0000e+00, 6.0663e-01],
          [6.7222e-01, 0.0000e+00, 0.0000e+00,  ..., 1.9523e-01,0.0000e+00, 7.4552e-01],
          [7.6922e-01, 0.0000e+00, 0.0000e+00,  ..., 2.9183e-01,0.0000e+00, 8.8441e-01]
         ]
	],


        [[[3.8767e-02, 0.0000e+00, 0.0000e+00,  ..., 2.2975e-03,5.7446e-01, 0.0000e+00],
          [4.2504e-02, 7.7990e-02, 0.0000e+00,  ..., 2.9087e-02,4.9452e-01, 0.0000e+00],
          [4.6241e-02, 1.5598e-01, 0.0000e+00,  ..., 5.5877e-02,4.1458e-01, 0.0000e+00],
          ...,
          [0.0000e+00, 5.9729e-01, 1.7794e+00,  ..., 4.1452e-01,5.4988e-01, 1.0653e+00],
          [0.0000e+00, 4.1452e-01, 1.5713e+00,  ..., 4.3931e-01,6.4155e-01, 1.1801e+00],
          [0.0000e+00, 2.3176e-01, 1.3632e+00,  ..., 4.6411e-01,7.3321e-01, 1.2950e+00]
         ],

         [[2.3761e-01, 0.0000e+00, 0.0000e+00,  ..., 1.1487e-03,4.0259e-01, 0.0000e+00],
          [1.8682e-01, 7.9244e-02, 0.0000e+00,  ..., 1.4544e-02,3.6260e-01, 0.0000e+00],
          [1.3603e-01, 1.5849e-01, 0.0000e+00,  ..., 2.7939e-02,3.2261e-01, 0.0000e+00],
          ...,
          [0.0000e+00, 1.1255e+00, 8.8969e-01,  ..., 7.6367e-01,9.6885e-01, 1.3478e+00],
          [3.4041e-02, 7.2316e-01, 8.3901e-01,  ..., 5.3454e-01,1.0559e+00, 1.2100e+00],
          [6.8081e-02, 3.2080e-01, 7.8834e-01,  ..., 3.0541e-01,1.1429e+00, 1.0722e+00]
         ],

         [[4.3645e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,2.3072e-01, 0.0000e+00],
          [3.3113e-01, 8.0499e-02, 0.0000e+00,  ..., 0.0000e+00,2.3068e-01, 0.0000e+00],
          [2.2582e-01, 1.6100e-01, 0.0000e+00,  ..., 0.0000e+00,2.3065e-01, 0.0000e+00],
          ...,
          [0.0000e+00, 1.6537e+00, 0.0000e+00,  ..., 1.1128e+00,1.3878e+00, 1.6304e+00],
          [6.8081e-02, 1.0318e+00, 1.0673e-01,  ..., 6.2977e-01,1.4702e+00, 1.2399e+00],
          [1.3616e-01, 4.0984e-01, 2.1347e-01,  ..., 1.4671e-01,1.5525e+00, 8.4945e-01]
         ],

         ...,

         [[0.0000e+00, 2.3937e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.6858e-01],
          [0.0000e+00, 1.5298e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.7305e-01],
          [0.0000e+00, 6.6594e-01, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.7752e-01],
          ...,
          [0.0000e+00, 1.4648e+00, 0.0000e+00,  ..., 1.5091e+00,0.0000e+00, 4.2158e-02],
          [0.0000e+00, 1.6270e+00, 0.0000e+00,  ..., 1.3410e+00,0.0000e+00, 9.1624e-02],
          [0.0000e+00, 1.7892e+00, 0.0000e+00,  ..., 1.1730e+00,0.0000e+00, 1.4109e-01]],

         [[1.3456e-02, 2.5342e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 9.5133e-02],
          [6.7281e-03, 1.8158e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 1.9829e-01],
          [0.0000e+00, 1.0975e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 3.0144e-01],
          ...,
          [0.0000e+00, 1.6274e+00, 8.1445e-02,  ..., 1.5885e+00,1.9944e-01, 2.7497e-02],
          [0.0000e+00, 1.9397e+00, 4.0722e-02,  ..., 1.2676e+00,1.7066e-01, 7.8321e-02],
          [0.0000e+00, 2.2520e+00, 0.0000e+00,  ..., 9.4676e-01,1.4188e-01, 1.2914e-01]],

         [[2.6912e-02, 2.6746e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 2.1684e-02],
          [1.3456e-02, 2.1018e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 2.2353e-01],
          [0.0000e+00, 1.5291e+00, 0.0000e+00,  ..., 0.0000e+00,0.0000e+00, 4.2537e-01],
          ...,
          [0.0000e+00, 1.7899e+00, 1.6289e-01,  ..., 1.6679e+00,3.9889e-01, 1.2835e-02],
          [0.0000e+00, 2.2524e+00, 8.1445e-02,  ..., 1.1942e+00,3.4133e-01, 6.5017e-02],
          [0.0000e+00, 2.7149e+00, 0.0000e+00,  ..., 7.2052e-01,2.8377e-01, 1.1720e-01]
         ]
	],


        [[[0.0000e+00, 0.0000e+00, 4.0027e-01,  ..., 4.8410e-01,
           0.0000e+00, 2.1529e-02],
          [2.0282e-01, 0.0000e+00, 7.5335e-01,  ..., 5.0351e-01,
           0.0000e+00, 6.1349e-02],
          [4.0564e-01, 0.0000e+00, 1.1064e+00,  ..., 5.2291e-01,
           0.0000e+00, 1.0117e-01],
          ...,
          [2.1032e-01, 0.0000e+00, 0.0000e+00,  ..., 5.1751e-01,
           0.0000e+00, 0.0000e+00],
          [1.0516e-01, 0.0000e+00, 1.3890e-01,  ..., 2.5875e-01,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 2.7780e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[2.7832e-01, 0.0000e+00, 2.0013e-01,  ..., 5.3671e-01,
           0.0000e+00, 1.6649e-01],
          [4.1329e-01, 5.7453e-02, 3.7668e-01,  ..., 8.4852e-01,
           0.0000e+00, 2.3294e-01],
          [5.4827e-01, 1.1491e-01, 5.5322e-01,  ..., 1.1603e+00,
           0.0000e+00, 2.9940e-01],
          ...,
          [1.0516e-01, 0.0000e+00, 1.3719e-03,  ..., 5.2783e-01,
           0.0000e+00, 6.0100e-02],
          [5.2580e-02, 0.0000e+00, 8.5612e-02,  ..., 2.6392e-01,
           0.0000e+00, 3.0050e-02],
          [0.0000e+00, 0.0000e+00, 1.6985e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[5.5663e-01, 0.0000e+00, 0.0000e+00,  ..., 5.8931e-01,
           0.0000e+00, 3.1144e-01],
          [6.2376e-01, 1.1491e-01, 0.0000e+00,  ..., 1.1935e+00,
           0.0000e+00, 4.0454e-01],
          [6.9089e-01, 2.2981e-01, 0.0000e+00,  ..., 1.7977e+00,
           0.0000e+00, 4.9763e-01],
          ...,
          [0.0000e+00, 0.0000e+00, 2.7439e-03,  ..., 5.3816e-01,
           0.0000e+00, 1.2020e-01],
          [0.0000e+00, 0.0000e+00, 3.2324e-02,  ..., 2.6908e-01,
           0.0000e+00, 6.0100e-02],
          [0.0000e+00, 0.0000e+00, 6.1904e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         ...,

         [[5.9512e-01, 0.0000e+00, 3.8417e-01,  ..., 0.0000e+00,
           0.0000e+00, 9.0328e-01],
          [7.4327e-01, 0.0000e+00, 2.3633e-01,  ..., 1.6930e-02,
           2.4616e-01, 8.2458e-01],
          [8.9143e-01, 0.0000e+00, 8.8485e-02,  ..., 3.3860e-02,
           4.9232e-01, 7.4588e-01],
          ...,
          [1.3619e+00, 7.0660e-01, 0.0000e+00,  ..., 4.8366e-01,
           4.6314e-02, 1.4390e+00],
          [8.5559e-01, 7.8314e-01, 0.0000e+00,  ..., 4.2197e-01,
           2.3157e-02, 1.1281e+00],
          [3.4924e-01, 8.5967e-01, 0.0000e+00,  ..., 3.6027e-01,
           0.0000e+00, 8.1720e-01]],

         [[6.1491e-01, 0.0000e+00, 7.3014e-01,  ..., 4.3681e-02,
           0.0000e+00, 1.0342e+00],
          [7.4444e-01, 0.0000e+00, 5.7319e-01,  ..., 1.5873e-01,
           1.2308e-01, 9.3152e-01],
          [8.7397e-01, 0.0000e+00, 4.1624e-01,  ..., 2.7378e-01,
           2.4616e-01, 8.2882e-01],
          ...,
          [1.0826e+00, 8.5710e-01, 0.0000e+00,  ..., 5.1503e-01,
           7.0558e-02, 1.3325e+00],
          [6.7622e-01, 8.8606e-01, 0.0000e+00,  ..., 4.5536e-01,
           1.0796e-01, 1.0618e+00],
          [2.6988e-01, 9.1502e-01, 0.0000e+00,  ..., 3.9569e-01,
           1.4537e-01, 7.9097e-01]],

         [[6.3469e-01, 0.0000e+00, 1.0761e+00,  ..., 8.7361e-02,
           0.0000e+00, 1.1651e+00],
          [7.4560e-01, 0.0000e+00, 9.1006e-01,  ..., 3.0053e-01,
           0.0000e+00, 1.0385e+00],
          [8.5651e-01, 0.0000e+00, 7.4400e-01,  ..., 5.1370e-01,
           0.0000e+00, 9.1176e-01],
          ...,
          [8.0316e-01, 1.0076e+00, 0.0000e+00,  ..., 5.4639e-01,
           9.4802e-02, 1.2261e+00],
          [4.9684e-01, 9.8899e-01, 0.0000e+00,  ..., 4.8875e-01,
           1.9277e-01, 9.9542e-01],
          [1.9051e-01, 9.7037e-01, 0.0000e+00,  ..., 4.3111e-01,
           2.9074e-01, 7.6474e-01]]]], device='cuda:0')
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
print('scores',scores)
print('caps_sorted',caps_sorted)
print('decode_lengths',decode_lengths)
print('alphas',alphas)
print('sort_ind',sort_ind)

输出:

scores 
	tensor([[[ 0.1060,  0.0381, -0.0553,  ...,  0.1050,  0.0229, -0.3909],
         [-0.2796,  0.0837, -0.1078,  ..., -0.0661, -0.0108, -0.0798],
         [-0.1874, -0.1624,  0.2605,  ..., -0.0012, -0.1666,  0.0373],
         ...,
         [-0.2718,  0.3011,  0.0513,  ...,  0.0440, -0.0508, -0.1148],
         [ 0.2318,  0.0158, -0.1945,  ...,  0.2243, -0.2355,  0.0454],
         [-0.2298, -0.2841, -0.0364,  ..., -0.2020,  0.0268, -0.2004]],

        [[ 0.1225, -0.0296, -0.0828,  ..., -0.0485,  0.0317, -0.0447],
         [ 0.0185, -0.0447, -0.0942,  ..., -0.3165, -0.0886,  0.1155],
         [ 0.0655, -0.0573,  0.1095,  ..., -0.3358, -0.1947, -0.2325],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.2505,  0.0975, -0.1628,  ..., -0.3082, -0.0931, -0.3930],
         [-0.3288, -0.0160, -0.1472,  ..., -0.4279, -0.2406, -0.0976],
         [-0.2872, -0.0300,  0.1982,  ..., -0.4447, -0.2396, -0.4520],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        ...,

        [[ 0.3742,  0.3620, -0.2312,  ...,  0.2165, -0.2903, -0.1331],
         [-0.0943,  0.2484, -0.2200,  ..., -0.1846, -0.5298, -0.3780],
         [-0.1539, -0.1260,  0.0355,  ..., -0.5587, -0.3220, -0.3822],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0985, -0.0282,  0.0497,  ..., -0.3646, -0.0202,  0.0497],
         [-0.1268,  0.1090, -0.0767,  ..., -0.2165, -0.0731, -0.1414],
         [-0.2086,  0.0095,  0.0765,  ..., -0.3952, -0.2807, -0.0504],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0374,  0.2859, -0.1097,  ...,  0.3040, -0.4128, -0.3252],
         [ 0.0172,  0.1231, -0.1188,  ..., -0.1551, -0.1775, -0.1338],
         [-0.1317,  0.0389,  0.0262,  ..., -0.1380, -0.0653, -0.2756],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0', grad_fn=<CopySlices>)
       
caps_sorted 
	tensor([[2631,    1,   87,  ...,    0,    0,    0],
        [2631,  198,  200,  ...,    0,    0,    0],
        [2631,    1,   99,  ...,    0,    0,    0],
        ...,
        [2631,  262, 1028,  ...,    0,    0,    0],
        [2631,   14,   15,  ...,    0,    0,    0],
        [2631,   99, 1114,  ...,    0,    0,    0]], device='cuda:0')
        
decode_lengths 
	[26, 20, 20, 19, 18, 17, 16, 16, 16, 15, 15, 15, 15, 14, 14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10, 10, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 8, 8, 8, 7, 7, 6]
	
alphas 
	tensor([[[0.0047, 0.0045, 0.0045,  ..., 0.0050, 0.0048, 0.0046],
         [0.0048, 0.0048, 0.0048,  ..., 0.0050, 0.0048, 0.0045],
         [0.0049, 0.0049, 0.0048,  ..., 0.0050, 0.0048, 0.0046],
         ...,
         [0.0049, 0.0049, 0.0048,  ..., 0.0049, 0.0048, 0.0046],
         [0.0050, 0.0049, 0.0048,  ..., 0.0049, 0.0048, 0.0046],
         [0.0050, 0.0049, 0.0048,  ..., 0.0049, 0.0048, 0.0046]],

        [[0.0051, 0.0050, 0.0049,  ..., 0.0051, 0.0049, 0.0046],
         [0.0051, 0.0049, 0.0049,  ..., 0.0049, 0.0049, 0.0047],
         [0.0050, 0.0049, 0.0049,  ..., 0.0049, 0.0049, 0.0046],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0046, 0.0048, 0.0048,  ..., 0.0050, 0.0047, 0.0045],
         [0.0049, 0.0050, 0.0049,  ..., 0.0048, 0.0046, 0.0045],
         [0.0048, 0.0049, 0.0049,  ..., 0.0048, 0.0047, 0.0045],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        ...,

        [[0.0046, 0.0046, 0.0046,  ..., 0.0053, 0.0053, 0.0052],
         [0.0043, 0.0044, 0.0043,  ..., 0.0052, 0.0052, 0.0052],
         [0.0043, 0.0044, 0.0043,  ..., 0.0052, 0.0052, 0.0052],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0045, 0.0045, 0.0046,  ..., 0.0051, 0.0055, 0.0057],
         [0.0046, 0.0046, 0.0045,  ..., 0.0054, 0.0057, 0.0059],
         [0.0046, 0.0045, 0.0045,  ..., 0.0054, 0.0057, 0.0059],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0052, 0.0051, 0.0050,  ..., 0.0046, 0.0045, 0.0044],
         [0.0049, 0.0049, 0.0048,  ..., 0.0048, 0.0047, 0.0046],
         [0.0050, 0.0049, 0.0048,  ..., 0.0048, 0.0046, 0.0045],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
       device='cuda:0', grad_fn=<CopySlices>)
       
sort_ind 
	tensor([30, 22,  2, 18, 55, 40, 52, 43, 25, 15,  1, 19, 38, 35, 41, 57, 11, 63,
        62, 26, 36, 33, 45, 61, 32, 28, 58, 47, 14, 12, 48, 59, 31, 56, 37, 39,
        46, 24, 17, 10, 29, 42,  8, 53, 51, 54, 13,  3,  7,  9, 23, 49, 34,  6,
        50, 60,  5,  0, 27, 16, 21, 20, 44,  4], device='cuda:0')
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
targets = caps_sorted[:, 1:]
print(targets)

输出:(后续处理应该是对以下的张量进行)

tensor([[  1,  99,  64,  ...,   0,   0,   0],
        [  1, 610,  42,  ...,   0,   0,   0],
        [  1,  87,   8,  ...,   0,   0,   0],
        ...,
        [  1,  46,  54,  ...,   0,   0,   0],
        [262, 105,  67,  ...,   0,   0,   0],
        [114,  22, 902,  ...,   0,   0,   0]], device='cuda:0')

输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5701364.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存