# Cross-entropy loss (torch.nn.CrossEntropyLoss), moved to the configured device.
criterion = nn.CrossEntropyLoss().to(device)
print(criterion)
输出:
CrossEntropyLoss()#交叉熵损失函数
# Per-channel (RGB) normalization; these mean/std values are the widely used
# ImageNet statistics expected by torchvision's pretrained backbones.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(normalize)
输出:
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Debug probe: print the epoch range bounds once, then leave the loop.
# Nothing is printed if start_epoch >= epochs (the range is empty).
for epoch in range(start_epoch, epochs):
    print(start_epoch)
    print(epochs)
    break
输出:
0
2

datasets.py文件中class CaptionDataset(Dataset)查看captions
# Inspect the encoded training captions that CaptionDataset (datasets.py) loads.
data_folder = 'image data/dataset_Flickr8k/'
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'

# Build the path once instead of repeating the join expression.
captions_path = os.path.join(data_folder, 'TRAIN' + '_CAPTIONS_' + data_name + '.json')
print(captions_path)

# JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
with open(captions_path, 'r', encoding='utf-8') as j:
    captions = json.load(j)
print(captions)
输出:
/***/image data/dataset_Flickr8k/TRAIN_CAPTIONS_flickr8k_5_cap_per_img_5_min_word_freq.json [ ... [2631, 198, 196, 165, 44, 1, 36, 288, 197, 104, 198, 2630, 133, 428, 408, 702, 2630, 38, 1476, 38, 198, 274, 87, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 1, 2630, 44, 1, 256, 36, 288, 197, 104, 1, 428, 702, 2630, 38, 1, 89, 8, 1, 2, 56, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 9, 32, 8, 9, 733, 734, 4, 2630, 8, 192, 33, 34, 9, 91, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 1, 29, 32, 8, 1, 733, 734, 721, 168, 8, 9, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 1, 120, 32, 8, 1, 544, 734, 140, 155, 1, 26, 27, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 1, 120, 32, 368, 162, 27, 9, 28, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2631, 1, 29, 32, 64, 1, 2, 2630, 4, 368, 8, 9, 78, 42, 28, 8, 9, 366, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ]datasets.py文件中class CaptionDataset(Dataset)查看caplens
# Inspect the training caption lengths that CaptionDataset (datasets.py) loads.
data_folder = 'image data/dataset_Flickr8k/'
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'

# Build the path once instead of repeating the join expression.
caplens_path = os.path.join(data_folder, 'TRAIN' + '_CAPLENS_' + data_name + '.json')
print(caplens_path)

# JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
with open(caplens_path, 'r', encoding='utf-8') as j:
    caplens = json.load(j)
print(caplens)
输出:
/***/image data/dataset_Flickr8k/TRAIN_CAPLENS_flickr8k_5_cap_per_img_5_min_word_freq.json
[...14, 10, 16, 11, 12, 8, 14, 8, 15, 9, 18, 15, 14, 24, 22, 17, 14, 15, 10, 19]

查看输入encoder的图片输出结果(很重要)
resnet101对一个batch的图片进行编码,输出的张量为torch.Size([64, 14, 14, 2048])
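1个batch大小为64,高为14,宽为14,通道数为2048(编码器输出的维度顺序为 batch × 高 × 宽 × 通道;2048 是 ResNet-101 最后一层特征的通道数)。即每张图片被编码为 14×14=196 个位置,每个位置是一个2048维的特征向量——这也是后面 alphas 中注意力权重约为 1/196≈0.005 的原因。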
# Debug probe: run the encoder on the first training batch only and print its output.
for i, (imgs, caps, caplens) in enumerate(train_loader):
    data_time.update(time.time() - start)
    # print(data_time)

    # Move to GPU, if available
    imgs = imgs.to(device)
    # Each caption row looks like [2631, 1, 2589, ..., 2632, 0, 0, ...].
    # NOTE(review): presumably 2631 = <start>, 2632 = <end>, 0 = <pad>; verify against the word map.
    caps = caps.to(device)
    caplens = caplens.to(device)

    # Encode the images; the printed result has shape torch.Size([64, 14, 14, 2048]).
    imgs = encoder(imgs)
    print(imgs)
    break  # inspect the first batch only
输出:
torch.Size([64, 14, 14, 2048]) tensor( [ [ [[9.0361e-02, 5.2851e-01, 2.1546e+00, ..., 0.0000e+00, 0.0000e+00, 3.5757e-01], [2.1035e-01, 6.9238e-01, 2.0149e+00, ..., 0.0000e+00, 9.5086e-01, 2.3448e-01], [3.3035e-01, 8.5626e-01, 1.8753e+00, ..., 0.0000e+00, 1.9017e+00, 1.1140e-01], ..., [0.0000e+00, 2.1303e-01, 3.7578e-01, ..., 1.7350e-02, 6.0041e-02, 3.7648e-01], [0.0000e+00, 1.9494e-01, 2.9860e-01, ..., 8.6750e-03, 3.0021e-02, 4.0226e-01], [0.0000e+00, 1.7685e-01, 2.2142e-01, ..., 0.0000e+00, 0.0000e+00, 4.2803e-01] ], [[4.5180e-02, 8.1279e-01, 1.0773e+00, ..., 1.8461e-01, 8.3581e-01, 1.7878e-01], [1.0518e-01, 9.8500e-01, 1.0075e+00, ..., 9.2305e-02, 1.7277e+00, 1.1724e-01], [1.6517e-01, 1.1572e+00, 9.3763e-01, ..., 0.0000e+00, 2.6196e+00, 5.5702e-02], ..., [0.0000e+00, 5.1433e-01, 1.8789e-01, ..., 8.6750e-03, 1.1592e-01, 1.8824e-01], [0.0000e+00, 5.3312e-01, 1.4930e-01, ..., 4.3375e-03, 5.7961e-02, 2.0113e-01], [0.0000e+00, 5.5192e-01, 1.1071e-01, ..., 0.0000e+00, 0.0000e+00, 2.1402e-01] ], [[0.0000e+00, 1.0971e+00, 0.0000e+00, ..., 3.6922e-01,1.6716e+00, 0.0000e+00], [0.0000e+00, 1.2776e+00, 0.0000e+00, ..., 1.8461e-01,2.5045e+00, 0.0000e+00], [0.0000e+00, 1.4582e+00, 0.0000e+00, ..., 0.0000e+00,3.3375e+00, 0.0000e+00], ..., [0.0000e+00, 8.1563e-01, 0.0000e+00, ..., 0.0000e+00,1.7180e-01, 0.0000e+00], [0.0000e+00, 8.7131e-01, 0.0000e+00, ..., 0.0000e+00,8.5902e-02, 0.0000e+00], [0.0000e+00, 9.2699e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00] ], ..., [[0.0000e+00, 1.1359e+00, 0.0000e+00, ..., 2.0694e-01,1.1926e+00, 9.9706e-01], [0.0000e+00, 1.1718e+00, 0.0000e+00, ..., 1.8501e-01,1.7524e+00, 9.7201e-01], [0.0000e+00, 1.2076e+00, 0.0000e+00, ..., 1.6309e-01,2.3122e+00, 9.4696e-01], ..., [0.0000e+00, 4.0909e-01, 0.0000e+00, ..., 2.9820e-02,8.7141e-01, 0.0000e+00], [0.0000e+00, 5.2929e-01, 1.7152e-01, ..., 1.4910e-02,8.3648e-01, 0.0000e+00], [0.0000e+00, 6.4950e-01, 3.4304e-01, ..., 0.0000e+00,8.0156e-01, 0.0000e+00] ], [[0.0000e+00, 1.0725e+00, 
0.0000e+00, ..., 1.0347e-01,9.2755e-01, 7.7783e-01], [0.0000e+00, 8.8376e-01, 0.0000e+00, ..., 9.2507e-02,1.2014e+00, 7.1052e-01], [0.0000e+00, 6.9500e-01, 0.0000e+00, ..., 8.1545e-02,1.4753e+00, 6.4321e-01], ..., [0.0000e+00, 6.4680e-01, 2.7190e-01, ..., 1.4910e-02,6.5424e-01, 0.0000e+00], [0.0000e+00, 6.4789e-01, 5.2873e-01, ..., 7.4550e-03,7.9419e-01, 0.0000e+00], [0.0000e+00, 6.4898e-01, 7.8556e-01, ..., 0.0000e+00,9.3414e-01, 0.0000e+00] ], [[0.0000e+00, 1.0091e+00, 0.0000e+00, ..., 0.0000e+00,6.6249e-01, 5.5859e-01], [0.0000e+00, 5.9575e-01, 0.0000e+00, ..., 0.0000e+00,6.5037e-01, 4.4903e-01], [0.0000e+00, 1.8237e-01, 0.0000e+00, ..., 0.0000e+00,6.3826e-01, 3.3946e-01], ..., [0.0000e+00, 8.8452e-01, 5.4380e-01, ..., 0.0000e+00,4.3707e-01, 0.0000e+00], [0.0000e+00, 7.6649e-01, 8.8594e-01, ..., 0.0000e+00,7.5190e-01, 0.0000e+00], [0.0000e+00, 6.4845e-01, 1.2281e+00, ..., 0.0000e+00,1.0667e+00, 0.0000e+00] ] ], [ [[3.1328e-01, 4.3026e-01, 1.2069e+00, ..., 3.4335e-01,0.0000e+00, 1.6301e-01], [1.5664e-01, 7.5903e-01, 1.1955e+00, ..., 3.6088e-01,0.0000e+00, 2.1009e-01], [0.0000e+00, 1.0878e+00, 1.1842e+00, ..., 3.7841e-01,0.0000e+00, 2.5717e-01], ..., [0.0000e+00, 9.5698e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.5350e+00], [0.0000e+00, 1.1680e+00, 6.2503e-01, ..., 0.0000e+00,0.0000e+00, 1.5504e+00], [0.0000e+00, 1.3790e+00, 1.2501e+00, ..., 0.0000e+00,0.0000e+00, 1.5659e+00] ], [[1.5664e-01, 5.1228e-01, 2.1516e+00, ..., 1.7679e-01,9.3442e-02, 1.5183e-01], [7.8320e-02, 6.2838e-01, 1.9712e+00, ..., 1.9270e-01,4.6721e-02, 1.9961e-01], [0.0000e+00, 7.4448e-01, 1.7908e+00, ..., 2.0861e-01,0.0000e+00, 2.4739e-01], ..., [3.5671e-02, 1.1973e+00, 4.5830e-01, ..., 0.0000e+00,0.0000e+00, 1.3950e+00], [1.0273e-01, 1.1930e+00, 7.2625e-01, ..., 0.0000e+00,0.0000e+00, 1.3092e+00], [1.6979e-01, 1.1886e+00, 9.9420e-01, ..., 0.0000e+00,0.0000e+00, 1.2233e+00]], [[0.0000e+00, 5.9430e-01, 3.0962e+00, ..., 1.0227e-02,1.8688e-01, 1.4064e-01], [0.0000e+00, 4.9773e-01, 2.7468e+00, 
..., 2.4522e-02,9.3442e-02, 1.8913e-01], [0.0000e+00, 4.0116e-01, 2.3975e+00, ..., 3.8816e-02,0.0000e+00, 2.3761e-01], ..., [7.1341e-02, 1.4377e+00, 9.1659e-01, ..., 0.0000e+00,0.0000e+00, 1.2551e+00], [2.0546e-01, 1.2180e+00, 8.2746e-01, ..., 0.0000e+00,0.0000e+00, 1.0679e+00], [3.3958e-01, 9.9825e-01, 7.3833e-01, ..., 0.0000e+00,0.0000e+00, 8.8069e-01]], ..., [[2.1133e-01, 6.3768e-01, 2.9080e+00, ..., 0.0000e+00,4.4643e-01, 7.4818e-01], [2.8350e-01, 6.8930e-01, 2.8615e+00, ..., 0.0000e+00,3.0719e-01, 7.3427e-01], [3.5566e-01, 7.4091e-01, 2.8149e+00, ..., 0.0000e+00,1.6795e-01, 7.2036e-01], ..., [1.4313e+00, 1.8079e+00, 1.1452e+00, ..., 0.0000e+00,2.4840e-03, 8.6259e-01], [1.0502e+00, 1.9889e+00, 1.4327e+00, ..., 0.0000e+00,6.1725e-02, 1.2689e+00], [6.6908e-01, 2.1700e+00, 1.7202e+00, ..., 0.0000e+00,1.2097e-01, 1.6752e+00]], [[1.5723e-01, 1.1452e+00, 3.1622e+00, ..., 1.5288e-01,2.2322e-01, 5.7149e-01], [1.6753e-01, 1.0057e+00, 3.2275e+00, ..., 9.2347e-02,2.3731e-01, 6.1787e-01], [1.7783e-01, 8.6625e-01, 3.2928e+00, ..., 3.1819e-02,2.5141e-01, 6.6425e-01], ..., [1.0436e+00, 1.6407e+00, 1.6283e+00, ..., 0.0000e+00,1.0191e-01, 8.7368e-01], [7.4215e-01, 1.8037e+00, 1.7287e+00, ..., 0.0000e+00,1.2338e-01, 1.1357e+00], [4.4071e-01, 1.9668e+00, 1.8292e+00, ..., 0.0000e+00,1.4485e-01, 1.3978e+00]], [[1.0312e-01, 1.6527e+00, 3.4163e+00, ..., 3.0575e-01,0.0000e+00, 3.9481e-01], [5.1561e-02, 1.3221e+00, 3.5935e+00, ..., 1.8469e-01,1.6743e-01, 5.0147e-01], [0.0000e+00, 9.9158e-01, 3.7707e+00, ..., 6.3637e-02,3.3487e-01, 6.0814e-01], ..., [6.5591e-01, 1.4736e+00, 2.1113e+00, ..., 0.0000e+00,2.0134e-01, 8.8477e-01], [4.3412e-01, 1.6185e+00, 2.0248e+00, ..., 0.0000e+00,1.8504e-01, 1.0025e+00], [2.1233e-01, 1.7635e+00, 1.9382e+00, ..., 0.0000e+00,1.6873e-01, 1.1203e+00] ] ], [ [[0.0000e+00, 0.0000e+00, 9.5545e-01, ..., 8.0960e-01,0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00, 8.0015e-01, ..., 5.5233e-01,0.0000e+00, 7.4726e-03], [0.0000e+00, 0.0000e+00, 6.4485e-01, ..., 
2.9505e-01,0.0000e+00, 1.4945e-02], ..., [0.0000e+00, 3.3634e-01, 8.3794e-01, ..., 0.0000e+00,0.0000e+00, 0.0000e+00], [0.0000e+00, 3.4318e-01, 8.7203e-01, ..., 0.0000e+00,0.0000e+00, 0.0000e+00], [0.0000e+00, 3.5001e-01, 9.0612e-01, ..., 0.0000e+00,0.0000e+00, 0.0000e+00]], [[0.0000e+00, 0.0000e+00, 5.6321e-01, ..., 6.7477e-01,0.0000e+00, 9.3554e-04], [0.0000e+00, 0.0000e+00, 4.7451e-01, ..., 5.0247e-01,0.0000e+00, 3.0180e-02], [0.0000e+00, 0.0000e+00, 3.8581e-01, ..., 3.3016e-01,0.0000e+00, 5.9425e-02], ..., [0.0000e+00, 6.7129e-01, 1.1972e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00], [0.0000e+00, 5.5759e-01, 1.6487e+00, ..., 0.0000e+00,1.0493e-02, 0.0000e+00], [0.0000e+00, 4.4389e-01, 2.1002e+00, ..., 0.0000e+00,2.0986e-02, 0.0000e+00]], [[0.0000e+00, 0.0000e+00, 1.7097e-01, ..., 5.3994e-01,0.0000e+00, 1.8711e-03], [0.0000e+00, 0.0000e+00, 1.4887e-01, ..., 4.5260e-01,0.0000e+00, 5.2887e-02], [0.0000e+00, 0.0000e+00, 1.2677e-01, ..., 3.6527e-01,0.0000e+00, 1.0390e-01], ..., [0.0000e+00, 1.0062e+00, 1.5565e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00], [0.0000e+00, 7.7201e-01, 2.4254e+00, ..., 0.0000e+00,2.0986e-02, 0.0000e+00], [0.0000e+00, 5.3778e-01, 3.2942e+00, ..., 0.0000e+00,4.1972e-02, 0.0000e+00]], ..., [[1.7784e-03, 6.2061e-02, 4.6169e-01, ..., 3.2132e-01,0.0000e+00, 1.1172e-02], [1.2874e-01, 9.7284e-02, 2.3084e-01, ..., 3.0055e-01,0.0000e+00, 6.3651e-02], [2.5569e-01, 1.3251e-01, 0.0000e+00, ..., 2.7979e-01,0.0000e+00, 1.1613e-01], ..., [5.3011e-01, 1.0567e-01, 0.0000e+00, ..., 4.1936e-01,5.7577e-02, 1.2729e-01], [4.6102e-01, 1.8947e-01, 1.7549e-01, ..., 3.3047e-01,9.7484e-02, 2.2460e-01], [3.9193e-01, 2.7327e-01, 3.5098e-01, ..., 2.4158e-01,1.3739e-01, 3.2192e-01]], [[1.8239e-01, 9.0274e-02, 1.1396e+00, ..., 3.3605e-01,0.0000e+00, 5.5860e-03], [2.7868e-01, 1.1954e-01, 8.2689e-01, ..., 3.8124e-01,0.0000e+00, 1.0846e-01], [3.7498e-01, 1.4880e-01, 5.1420e-01, ..., 4.2642e-01,0.0000e+00, 2.1132e-01], ..., [4.7168e-01, 2.4409e-01, 4.4172e-01, ..., 
4.4215e-01,9.2666e-02, 2.3359e-01], [4.8129e-01, 3.6893e-01, 6.0175e-01, ..., 3.4435e-01,1.7973e-01, 2.9912e-01], [4.9090e-01, 4.9378e-01, 7.6177e-01, ..., 2.4655e-01,2.6680e-01, 3.6464e-01]], [[3.6300e-01, 1.1849e-01, 1.8174e+00, ..., 3.5079e-01,0.0000e+00, 0.0000e+00], [4.2863e-01, 1.4179e-01, 1.4229e+00, ..., 4.6192e-01,0.0000e+00, 1.5326e-01], [4.9426e-01, 1.6509e-01, 1.0284e+00, ..., 5.7305e-01,0.0000e+00, 3.0652e-01], ..., [4.1326e-01, 3.8251e-01, 8.8344e-01, ..., 4.6493e-01,1.2776e-01, 3.3989e-01], [5.0157e-01, 5.4840e-01, 1.0280e+00, ..., 3.5823e-01,2.6198e-01, 3.7363e-01], [5.8988e-01, 7.1429e-01, 1.1726e+00, ..., 2.5153e-01,3.9621e-01, 4.0737e-01] ] ], ..., [[[2.4759e-01, 0.0000e+00, 0.0000e+00, ..., 4.0495e-02,1.3125e-01, 0.0000e+00], [2.2211e-01, 0.0000e+00, 0.0000e+00, ..., 2.0248e-02,2.8641e-01, 2.7549e-02], [1.9663e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,4.4156e-01, 5.5099e-02], ..., [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,4.2970e-02, 1.7396e-01], [0.0000e+00, 9.1169e-02, 0.0000e+00, ..., 0.0000e+00,2.1485e-02, 1.5837e-01], [0.0000e+00, 1.8234e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.4279e-01] ], [[1.4166e-01, 0.0000e+00, 0.0000e+00, ..., 7.0862e-02,6.5624e-02, 0.0000e+00], [1.3300e-01, 0.0000e+00, 0.0000e+00, ..., 1.2258e-01,2.0074e-01, 1.3775e-02], [1.2434e-01, 0.0000e+00, 0.0000e+00, ..., 1.7431e-01,3.3585e-01, 2.7549e-02], ..., [0.0000e+00, 9.8729e-02, 0.0000e+00, ..., 0.0000e+00,3.9813e-01, 2.7212e-01], [0.0000e+00, 1.9067e-01, 0.0000e+00, ..., 0.0000e+00,1.9907e-01, 1.7175e-01], [0.0000e+00, 2.8260e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 7.1393e-02] ], [[3.5734e-02, 0.0000e+00, 0.0000e+00, ..., 1.0123e-01,0.0000e+00, 0.0000e+00], [4.3895e-02, 0.0000e+00, 0.0000e+00, ..., 2.2492e-01,1.1507e-01, 0.0000e+00], [5.2056e-02, 0.0000e+00, 0.0000e+00, ..., 3.4861e-01,2.3014e-01, 0.0000e+00], ..., [0.0000e+00, 1.9746e-01, 0.0000e+00, ..., 0.0000e+00,7.5330e-01, 3.7028e-01], [0.0000e+00, 2.9016e-01, 0.0000e+00, ..., 
0.0000e+00,3.7665e-01, 1.8514e-01], [0.0000e+00, 3.8287e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00] ], ..., [[4.6125e-02, 3.5640e-01, 2.0939e+00, ..., 0.0000e+00,0.0000e+00, 5.0760e-02], [2.3063e-02, 3.1719e-01, 1.5483e+00, ..., 1.7244e-02,0.0000e+00, 2.5380e-02], [0.0000e+00, 2.7799e-01, 1.0028e+00, ..., 3.4487e-02,0.0000e+00, 0.0000e+00], ..., [7.0788e-01, 2.2093e-01, 0.0000e+00, ..., 3.0831e-01,0.0000e+00, 1.2356e+00], [1.1000e+00, 1.8086e-01, 0.0000e+00, ..., 3.0003e-01,0.0000e+00, 1.3106e+00], [1.4922e+00, 1.4078e-01, 0.0000e+00, ..., 2.9175e-01,0.0000e+00, 1.3855e+00] ], [[2.3063e-02, 3.1117e-01, 1.5658e+00, ..., 0.0000e+00,1.3866e-01, 2.5380e-02], [1.1531e-02, 2.9614e-01, 1.3523e+00, ..., 8.6218e-03,6.9332e-02, 1.2690e-02], [0.0000e+00, 2.8111e-01, 1.1389e+00, ..., 1.7244e-02,0.0000e+00, 0.0000e+00], ..., [6.4155e-01, 1.1047e-01, 0.0000e+00, ..., 2.0347e-01,0.0000e+00, 9.2113e-01], [8.8613e-01, 9.0429e-02, 0.0000e+00, ..., 2.4763e-01,0.0000e+00, 1.0280e+00], [1.1307e+00, 7.0391e-02, 0.0000e+00, ..., 2.9179e-01,0.0000e+00, 1.1349e+00] ], [[0.0000e+00, 2.6594e-01, 1.0377e+00, ..., 0.0000e+00,2.7733e-01, 0.0000e+00], [0.0000e+00, 2.7508e-01, 1.1563e+00, ..., 0.0000e+00,1.3866e-01, 0.0000e+00], [0.0000e+00, 2.8422e-01, 1.2750e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00], ..., [5.7521e-01, 0.0000e+00, 0.0000e+00, ..., 9.8626e-02,0.0000e+00, 6.0663e-01], [6.7222e-01, 0.0000e+00, 0.0000e+00, ..., 1.9523e-01,0.0000e+00, 7.4552e-01], [7.6922e-01, 0.0000e+00, 0.0000e+00, ..., 2.9183e-01,0.0000e+00, 8.8441e-01] ] ], [[[3.8767e-02, 0.0000e+00, 0.0000e+00, ..., 2.2975e-03,5.7446e-01, 0.0000e+00], [4.2504e-02, 7.7990e-02, 0.0000e+00, ..., 2.9087e-02,4.9452e-01, 0.0000e+00], [4.6241e-02, 1.5598e-01, 0.0000e+00, ..., 5.5877e-02,4.1458e-01, 0.0000e+00], ..., [0.0000e+00, 5.9729e-01, 1.7794e+00, ..., 4.1452e-01,5.4988e-01, 1.0653e+00], [0.0000e+00, 4.1452e-01, 1.5713e+00, ..., 4.3931e-01,6.4155e-01, 1.1801e+00], [0.0000e+00, 2.3176e-01, 1.3632e+00, ..., 
4.6411e-01,7.3321e-01, 1.2950e+00] ], [[2.3761e-01, 0.0000e+00, 0.0000e+00, ..., 1.1487e-03,4.0259e-01, 0.0000e+00], [1.8682e-01, 7.9244e-02, 0.0000e+00, ..., 1.4544e-02,3.6260e-01, 0.0000e+00], [1.3603e-01, 1.5849e-01, 0.0000e+00, ..., 2.7939e-02,3.2261e-01, 0.0000e+00], ..., [0.0000e+00, 1.1255e+00, 8.8969e-01, ..., 7.6367e-01,9.6885e-01, 1.3478e+00], [3.4041e-02, 7.2316e-01, 8.3901e-01, ..., 5.3454e-01,1.0559e+00, 1.2100e+00], [6.8081e-02, 3.2080e-01, 7.8834e-01, ..., 3.0541e-01,1.1429e+00, 1.0722e+00] ], [[4.3645e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,2.3072e-01, 0.0000e+00], [3.3113e-01, 8.0499e-02, 0.0000e+00, ..., 0.0000e+00,2.3068e-01, 0.0000e+00], [2.2582e-01, 1.6100e-01, 0.0000e+00, ..., 0.0000e+00,2.3065e-01, 0.0000e+00], ..., [0.0000e+00, 1.6537e+00, 0.0000e+00, ..., 1.1128e+00,1.3878e+00, 1.6304e+00], [6.8081e-02, 1.0318e+00, 1.0673e-01, ..., 6.2977e-01,1.4702e+00, 1.2399e+00], [1.3616e-01, 4.0984e-01, 2.1347e-01, ..., 1.4671e-01,1.5525e+00, 8.4945e-01] ], ..., [[0.0000e+00, 2.3937e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.6858e-01], [0.0000e+00, 1.5298e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.7305e-01], [0.0000e+00, 6.6594e-01, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.7752e-01], ..., [0.0000e+00, 1.4648e+00, 0.0000e+00, ..., 1.5091e+00,0.0000e+00, 4.2158e-02], [0.0000e+00, 1.6270e+00, 0.0000e+00, ..., 1.3410e+00,0.0000e+00, 9.1624e-02], [0.0000e+00, 1.7892e+00, 0.0000e+00, ..., 1.1730e+00,0.0000e+00, 1.4109e-01]], [[1.3456e-02, 2.5342e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 9.5133e-02], [6.7281e-03, 1.8158e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 1.9829e-01], [0.0000e+00, 1.0975e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 3.0144e-01], ..., [0.0000e+00, 1.6274e+00, 8.1445e-02, ..., 1.5885e+00,1.9944e-01, 2.7497e-02], [0.0000e+00, 1.9397e+00, 4.0722e-02, ..., 1.2676e+00,1.7066e-01, 7.8321e-02], [0.0000e+00, 2.2520e+00, 0.0000e+00, ..., 9.4676e-01,1.4188e-01, 1.2914e-01]], [[2.6912e-02, 2.6746e+00, 0.0000e+00, ..., 
0.0000e+00,0.0000e+00, 2.1684e-02], [1.3456e-02, 2.1018e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 2.2353e-01], [0.0000e+00, 1.5291e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 4.2537e-01], ..., [0.0000e+00, 1.7899e+00, 1.6289e-01, ..., 1.6679e+00,3.9889e-01, 1.2835e-02], [0.0000e+00, 2.2524e+00, 8.1445e-02, ..., 1.1942e+00,3.4133e-01, 6.5017e-02], [0.0000e+00, 2.7149e+00, 0.0000e+00, ..., 7.2052e-01,2.8377e-01, 1.1720e-01] ] ], [[[0.0000e+00, 0.0000e+00, 4.0027e-01, ..., 4.8410e-01, 0.0000e+00, 2.1529e-02], [2.0282e-01, 0.0000e+00, 7.5335e-01, ..., 5.0351e-01, 0.0000e+00, 6.1349e-02], [4.0564e-01, 0.0000e+00, 1.1064e+00, ..., 5.2291e-01, 0.0000e+00, 1.0117e-01], ..., [2.1032e-01, 0.0000e+00, 0.0000e+00, ..., 5.1751e-01, 0.0000e+00, 0.0000e+00], [1.0516e-01, 0.0000e+00, 1.3890e-01, ..., 2.5875e-01, 0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00, 2.7780e-01, ..., 0.0000e+00, 0.0000e+00, 0.0000e+00]], [[2.7832e-01, 0.0000e+00, 2.0013e-01, ..., 5.3671e-01, 0.0000e+00, 1.6649e-01], [4.1329e-01, 5.7453e-02, 3.7668e-01, ..., 8.4852e-01, 0.0000e+00, 2.3294e-01], [5.4827e-01, 1.1491e-01, 5.5322e-01, ..., 1.1603e+00, 0.0000e+00, 2.9940e-01], ..., [1.0516e-01, 0.0000e+00, 1.3719e-03, ..., 5.2783e-01, 0.0000e+00, 6.0100e-02], [5.2580e-02, 0.0000e+00, 8.5612e-02, ..., 2.6392e-01, 0.0000e+00, 3.0050e-02], [0.0000e+00, 0.0000e+00, 1.6985e-01, ..., 0.0000e+00, 0.0000e+00, 0.0000e+00]], [[5.5663e-01, 0.0000e+00, 0.0000e+00, ..., 5.8931e-01, 0.0000e+00, 3.1144e-01], [6.2376e-01, 1.1491e-01, 0.0000e+00, ..., 1.1935e+00, 0.0000e+00, 4.0454e-01], [6.9089e-01, 2.2981e-01, 0.0000e+00, ..., 1.7977e+00, 0.0000e+00, 4.9763e-01], ..., [0.0000e+00, 0.0000e+00, 2.7439e-03, ..., 5.3816e-01, 0.0000e+00, 1.2020e-01], [0.0000e+00, 0.0000e+00, 3.2324e-02, ..., 2.6908e-01, 0.0000e+00, 6.0100e-02], [0.0000e+00, 0.0000e+00, 6.1904e-02, ..., 0.0000e+00, 0.0000e+00, 0.0000e+00]], ..., [[5.9512e-01, 0.0000e+00, 3.8417e-01, ..., 0.0000e+00, 0.0000e+00, 9.0328e-01], [7.4327e-01, 0.0000e+00, 
2.3633e-01, ..., 1.6930e-02, 2.4616e-01, 8.2458e-01], [8.9143e-01, 0.0000e+00, 8.8485e-02, ..., 3.3860e-02, 4.9232e-01, 7.4588e-01], ..., [1.3619e+00, 7.0660e-01, 0.0000e+00, ..., 4.8366e-01, 4.6314e-02, 1.4390e+00], [8.5559e-01, 7.8314e-01, 0.0000e+00, ..., 4.2197e-01, 2.3157e-02, 1.1281e+00], [3.4924e-01, 8.5967e-01, 0.0000e+00, ..., 3.6027e-01, 0.0000e+00, 8.1720e-01]], [[6.1491e-01, 0.0000e+00, 7.3014e-01, ..., 4.3681e-02, 0.0000e+00, 1.0342e+00], [7.4444e-01, 0.0000e+00, 5.7319e-01, ..., 1.5873e-01, 1.2308e-01, 9.3152e-01], [8.7397e-01, 0.0000e+00, 4.1624e-01, ..., 2.7378e-01, 2.4616e-01, 8.2882e-01], ..., [1.0826e+00, 8.5710e-01, 0.0000e+00, ..., 5.1503e-01, 7.0558e-02, 1.3325e+00], [6.7622e-01, 8.8606e-01, 0.0000e+00, ..., 4.5536e-01, 1.0796e-01, 1.0618e+00], [2.6988e-01, 9.1502e-01, 0.0000e+00, ..., 3.9569e-01, 1.4537e-01, 7.9097e-01]], [[6.3469e-01, 0.0000e+00, 1.0761e+00, ..., 8.7361e-02, 0.0000e+00, 1.1651e+00], [7.4560e-01, 0.0000e+00, 9.1006e-01, ..., 3.0053e-01, 0.0000e+00, 1.0385e+00], [8.5651e-01, 0.0000e+00, 7.4400e-01, ..., 5.1370e-01, 0.0000e+00, 9.1176e-01], ..., [8.0316e-01, 1.0076e+00, 0.0000e+00, ..., 5.4639e-01, 9.4802e-02, 1.2261e+00], [4.9684e-01, 9.8899e-01, 0.0000e+00, ..., 4.8875e-01, 1.9277e-01, 9.9542e-01], [1.9051e-01, 9.7037e-01, 0.0000e+00, ..., 4.3111e-01, 2.9074e-01, 7.6474e-01]]]], device='cuda:0')
# Run the decoder on the encoded batch and print each returned value.
# In the printed output, decode_lengths is in descending order and sort_ind is
# the matching permutation of batch indices used to produce caps_sorted.
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
print('scores', scores)
print('caps_sorted', caps_sorted)
print('decode_lengths', decode_lengths)
print('alphas', alphas)
print('sort_ind', sort_ind)
输出:
scores tensor([[[ 0.1060, 0.0381, -0.0553, ..., 0.1050, 0.0229, -0.3909], [-0.2796, 0.0837, -0.1078, ..., -0.0661, -0.0108, -0.0798], [-0.1874, -0.1624, 0.2605, ..., -0.0012, -0.1666, 0.0373], ..., [-0.2718, 0.3011, 0.0513, ..., 0.0440, -0.0508, -0.1148], [ 0.2318, 0.0158, -0.1945, ..., 0.2243, -0.2355, 0.0454], [-0.2298, -0.2841, -0.0364, ..., -0.2020, 0.0268, -0.2004]], [[ 0.1225, -0.0296, -0.0828, ..., -0.0485, 0.0317, -0.0447], [ 0.0185, -0.0447, -0.0942, ..., -0.3165, -0.0886, 0.1155], [ 0.0655, -0.0573, 0.1095, ..., -0.3358, -0.1947, -0.2325], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.2505, 0.0975, -0.1628, ..., -0.3082, -0.0931, -0.3930], [-0.3288, -0.0160, -0.1472, ..., -0.4279, -0.2406, -0.0976], [-0.2872, -0.0300, 0.1982, ..., -0.4447, -0.2396, -0.4520], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], ..., [[ 0.3742, 0.3620, -0.2312, ..., 0.2165, -0.2903, -0.1331], [-0.0943, 0.2484, -0.2200, ..., -0.1846, -0.5298, -0.3780], [-0.1539, -0.1260, 0.0355, ..., -0.5587, -0.3220, -0.3822], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0985, -0.0282, 0.0497, ..., -0.3646, -0.0202, 0.0497], [-0.1268, 0.1090, -0.0767, ..., -0.2165, -0.0731, -0.1414], [-0.2086, 0.0095, 0.0765, ..., -0.3952, -0.2807, -0.0504], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[-0.0374, 0.2859, -0.1097, ..., 0.3040, -0.4128, -0.3252], [ 0.0172, 0.1231, -0.1188, ..., -0.1551, -0.1775, -0.1338], [-0.1317, 0.0389, 0.0262, ..., -0.1380, -0.0653, -0.2756], ..., [ 0.0000, 0.0000, 
0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], device='cuda:0', grad_fn=) caps_sorted tensor([[2631, 1, 87, ..., 0, 0, 0], [2631, 198, 200, ..., 0, 0, 0], [2631, 1, 99, ..., 0, 0, 0], ..., [2631, 262, 1028, ..., 0, 0, 0], [2631, 14, 15, ..., 0, 0, 0], [2631, 99, 1114, ..., 0, 0, 0]], device='cuda:0') decode_lengths [26, 20, 20, 19, 18, 17, 16, 16, 16, 15, 15, 15, 15, 14, 14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10, 10, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 8, 8, 8, 7, 7, 6] alphas tensor([[[0.0047, 0.0045, 0.0045, ..., 0.0050, 0.0048, 0.0046], [0.0048, 0.0048, 0.0048, ..., 0.0050, 0.0048, 0.0045], [0.0049, 0.0049, 0.0048, ..., 0.0050, 0.0048, 0.0046], ..., [0.0049, 0.0049, 0.0048, ..., 0.0049, 0.0048, 0.0046], [0.0050, 0.0049, 0.0048, ..., 0.0049, 0.0048, 0.0046], [0.0050, 0.0049, 0.0048, ..., 0.0049, 0.0048, 0.0046]], [[0.0051, 0.0050, 0.0049, ..., 0.0051, 0.0049, 0.0046], [0.0051, 0.0049, 0.0049, ..., 0.0049, 0.0049, 0.0047], [0.0050, 0.0049, 0.0049, ..., 0.0049, 0.0049, 0.0046], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[0.0046, 0.0048, 0.0048, ..., 0.0050, 0.0047, 0.0045], [0.0049, 0.0050, 0.0049, ..., 0.0048, 0.0046, 0.0045], [0.0048, 0.0049, 0.0049, ..., 0.0048, 0.0047, 0.0045], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], ..., [[0.0046, 0.0046, 0.0046, ..., 0.0053, 0.0053, 0.0052], [0.0043, 0.0044, 0.0043, ..., 0.0052, 0.0052, 0.0052], [0.0043, 0.0044, 0.0043, ..., 0.0052, 0.0052, 0.0052], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 
..., 0.0000, 0.0000, 0.0000]], [[0.0045, 0.0045, 0.0046, ..., 0.0051, 0.0055, 0.0057], [0.0046, 0.0046, 0.0045, ..., 0.0054, 0.0057, 0.0059], [0.0046, 0.0045, 0.0045, ..., 0.0054, 0.0057, 0.0059], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[0.0052, 0.0051, 0.0050, ..., 0.0046, 0.0045, 0.0044], [0.0049, 0.0049, 0.0048, ..., 0.0048, 0.0047, 0.0046], [0.0050, 0.0049, 0.0048, ..., 0.0048, 0.0046, 0.0045], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], device='cuda:0', grad_fn= ) sort_ind tensor([30, 22, 2, 18, 55, 40, 52, 43, 25, 15, 1, 19, 38, 35, 41, 57, 11, 63, 62, 26, 36, 33, 45, 61, 32, 28, 58, 47, 14, 12, 48, 59, 31, 56, 37, 39, 46, 24, 17, 10, 29, 42, 8, 53, 51, 54, 13, 3, 7, 9, 23, 49, 34, 6, 50, 60, 5, 0, 27, 16, 21, 20, 44, 4], device='cuda:0')
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>.
# Dropping column 0 removes the leading start token from every sorted caption.
targets = caps_sorted[:, 1:]
print(targets)
输出:(后续的计算应针对以下张量进行)
tensor([[ 1, 99, 64, ..., 0, 0, 0], [ 1, 610, 42, ..., 0, 0, 0], [ 1, 87, 8, ..., 0, 0, 0], ..., [ 1, 46, 54, ..., 0, 0, 0], [262, 105, 67, ..., 0, 0, 0], [114, 22, 902, ..., 0, 0, 0]], device='cuda:0')
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
输出:
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)