import torch import torch.nn as nn import torch.nn.functional as F class TransR(nn.Module): def __init__(self,ent_tot,rel_tot,dim_e = 100,dim_r = 100,p_norm = 1,norm_flag = True,rand_init = False,margin = None): super(TransR,self).__init__() self.ent_tot = ent_tot self.rel_tot = rel_tot self.dim_e = dim_e self.dim_r = dim_r self.norm_flag = norm_flag self.p_norm = p_norm self.rand_init = rand_init self.ent_embeddings = nn.Embedding(self.ent_tot,self.dim_e) self.rel_embeddings = nn.Embedding(self.rel_tot,self.dim_r) nn.init.xavier_uniform_(self.ent_embeddings.weight.data) nn.init.xavier_uniform_(self.rel_embeddings.weight.data) # 增加参数,每一个relation对应的entity空间到relation空间的映射矩阵M_r self.transfer_matrix = nn.Embedding(self.rel_tot,self.dim_e * self.dim_r) if not self.rand_init: identity = torch.zeros(self.dim_e,self.dim_r) for i in range(min(self.dim_e,self.dim_r)): identity[i][i] = 1 identity = identity.view(self.dim_r * self.dim_e) for i in range(self.rel_tot): self.transfer_matrix.weight.data[i] = identity else: nn.init.xavier_uniform_(self.transfer_matrix.weight.data) if margin != None: self.margin = nn.Parameter(torch.Tensor([margin])) self.margin.requires_grad = False self.margin_flag = True else: self.margin_flag = False def _calc(self,h,t,r,mode): if self.norm_flag: h = F.normalize(h,2,-1) r = F.normalize(r,2,-1) t = F.normalize(t,2,-1) if mode != 'normal': h = h.view(-1,r.shape[0],h.shape[-1]) t = t.view(-1,r.shape[0],t.shape[0]) r = r.view(-1,r.shape[0],r.shape[-1]) if mode == 'head_batch': score = h + (r - t) else: score = (h + r) - t score = torch.norm(score,self.p_norm,-1).flatten() return score # [k] def _transfer(self,e,r_transfer): r_transfer = r_transfer.view(-1,self.dim_e,self.dim_r) # [t,dim_e,dim_r] if e.shape[0] != r_transfer.shape[0]: e = e.view(-1,r_transfer.shape[0],self.dim_e).permute(1,0,2) e = torch.matmul(e,r_transfer).permute(1,0,2) else: e = e.view(-1,1,self.dim_e) # [t,1,dim_e] e = torch.matmul(e,r_transfer) # [t,1,dim_r] return 
e.view(-1,self.dim_r) def forward(self,data): batch_h = data['batch_h'] batch_t = data['batch_t'] batch_r = data['batch_r'] mode = data['mode'] h = self.ent_embeddings(batch_h) t = self.ent_embeddings(batch_t) r = self.rel_embeddings(batch_r) r_transfer = self.transfer_matrix(batch_r) # 计算h通过矩阵M_r将实体映射到relation空间中 h = self._transfer(h,r_transfer) t = self._transfer(t,r_transfer) score = self._calc(h,t,r,mode) if self.margin_flag: return self.margin - score else: return score def regularization(self,data): batch_h = data['batch_h'] batch_t = data['batch_t'] batch_r = data['batch_r'] h = self.ent_embeddings(batch_h) t = self.ent_embeddings(batch_t) r = self.rel_embeddings(batch_r) r_transfer = self.transfer_matrix(batch_r) regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2) + torch.mean(r_transfer ** 2) ) / 4 return regul def predict(self,data): score = self.forward(data) if self.margin_flag: score = self.margin - score return score.cpu().data.numpy() else: return score.cpu().data.numpy() transr = TransR(train_dataloader.get_ent_tot(),train_dataloader.get_rel_tot()) train_dataloader = TrainDataLoader( in_path = "./benchmarks/FB15K237_tiny/", nbatches = 100, threads = 8, # 负采样 sampling_mode = 'normal', # bern构建负样本方式 bern_flag = 1, # 负样本同replace entities neg_ent = 25, neg_rel = 0 ) # dataloader for test test_dataloader = TestDataLoader("./benchmarks/FB15K237_tiny/", "link") loss = MarginLoss() model = NegativeSampling(transr,loss,batch_size=train_dataloader.batch_size) trainer = Traniner(model = model,data_loader = train_dataloader) trainer.run()
# (scraped blog footer, commented out so the file parses)
# 欢迎分享,转载请注明来源:内存溢出 — "Feel free to share; please credit the source: memory-overflow blog"
# 评论列表(0条) — "Comment list (0 comments)"; leftover page text, not code