import torch
from model import Model


def _resolve_setting(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None.

    Keeps ``train`` backward compatible with scripts that define
    ``temperature`` / ``epoch`` / ``epochs`` as module globals instead of
    passing them in.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        # Same exception type the original bare-global lookup raised.
        raise NameError(f"name {name!r} is not defined") from None


# train for one epoch to learn unique features
def train(net, data_loader, train_optimizer, *, temperature=None, epoch=None, epochs=None):
    """Train ``net`` for one epoch with a contrastive (NT-Xent-style) loss.

    Each item from ``data_loader`` is unpacked as ``(pos_1, pos_2, target)``:
    two views of the same batch plus a target that this function ignores.
    ``net(x)`` must return a ``(feature, out)`` pair; only ``out`` enters the
    loss. For every one of the 2B stacked outputs, the similarity with its
    other view is the numerator, and its similarities with all other rows
    (everything except itself) are summed as the denominator.

    NOTE(review): the cosine-similarity reading of this loss assumes ``net``
    L2-normalizes ``out`` -- not visible here, confirm in ``Model``.

    Args:
        net: Model being trained; switched to train mode.
        data_loader: Iterable of ``(pos_1, pos_2, target)`` batches. The two
            views are moved to the current CUDA device.
        train_optimizer: Optimizer stepped once per batch.
        temperature: Temperature dividing the similarities. When omitted, the
            module-level ``temperature`` global is used (original behavior).
        epoch: Current epoch number, used only in the progress-bar text.
            Falls back to the module-level ``epoch`` global when omitted.
        epochs: Total epoch count, used only in the progress-bar text.
            Falls back to the module-level ``epochs`` global when omitted.

    Returns:
        The epoch's mean loss per sample (per-batch losses weighted by the
        actual size of each batch).

    Raises:
        NameError: a setting was omitted and no module-level global exists.
        ZeroDivisionError: ``data_loader`` yielded no batches.
    """
    temperature = _resolve_setting("temperature", temperature)
    epoch = _resolve_setting("epoch", epoch)
    epochs = _resolve_setting("epochs", epochs)

    net.train()
    total_loss, total_num = 0.0, 0
    # NOTE(review): `tqdm` is not imported in this snippet; presumably
    # `from tqdm import tqdm` exists in the full script -- confirm.
    train_bar = tqdm(data_loader)
    for pos_1, pos_2, _target in train_bar:
        pos_1 = pos_1.cuda(non_blocking=True)
        pos_2 = pos_2.cuda(non_blocking=True)
        # Use the real batch size: the last batch can be smaller than the
        # configured one when the loader keeps incomplete batches.
        n = pos_1.size(0)

        _, out_1 = net(pos_1)
        _, out_2 = net(pos_2)
        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        # [2*B, 2*B]
        sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
        # Exclude each row's similarity with itself (the diagonal).
        mask = ~torch.eye(2 * n, dtype=torch.bool, device=sim_matrix.device)
        # [2*B, 2*B-1]
        sim_matrix = sim_matrix.masked_select(mask).view(2 * n, -1)

        # compute loss
        # [B] similarity between the two views of each sample ...
        pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        # ... repeated so both views get a positive term: [2*B]
        pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
        loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += n
        total_loss += loss.item() * n
        train_bar.set_description(
            'Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num)
        )

    return total_loss / total_num

# ----------------------------------------------------------------
# Copyright notice: this is an original article by CSDN blogger
# "xxxxxxxxxx13", published under the CC 4.0 BY-SA license. When
# reposting, include a link to the original source and this notice.
# Original link: https://blog.csdn.net/xxxxxxxxxx13/article/details/110820373
# Sharing is welcome; please credit the source when reposting: 内存溢出 ("Memory Overflow").
# Comment list (0 comments)