小黑demo:TransR

小黑demo:TransR,第1张

小黑demo:TransR


import torch
import torch.nn as nn
import torch.nn.functional as F
class TransR(nn.Module):
    """TransR knowledge-graph embedding model.

    Entities live in a ``dim_e``-dimensional space and relations in a
    ``dim_r``-dimensional space. Every relation ``r`` owns a projection matrix
    ``M_r`` of shape ``(dim_e, dim_r)`` (stored flattened, one row per relation,
    in ``transfer_matrix``) that maps entity embeddings into the relation
    space. A triple is scored by the distance ``|| h M_r + r - t M_r ||_p``.

    Args:
        ent_tot: number of entities (rows of ``ent_embeddings``).
        rel_tot: number of relations (rows of ``rel_embeddings`` and
            ``transfer_matrix``).
        dim_e: entity embedding size.
        dim_r: relation embedding size.
        p_norm: order of the norm used for the distance.
        norm_flag: if True, L2-normalise projected h, t and r before scoring.
        rand_init: if True, Xavier-initialise every ``M_r``; otherwise start
            each ``M_r`` as the rectangular identity ``eye(dim_e, dim_r)``.
        margin: optional fixed margin. When given, ``forward`` returns
            ``margin - distance`` instead of the raw distance.
    """

    def __init__(self, ent_tot, rel_tot, dim_e=100, dim_r=100, p_norm=1,
                 norm_flag=True, rand_init=False, margin=None):
        super(TransR, self).__init__()
        self.ent_tot = ent_tot
        self.rel_tot = rel_tot
        self.dim_e = dim_e
        self.dim_r = dim_r
        self.norm_flag = norm_flag
        self.p_norm = p_norm
        self.rand_init = rand_init

        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r)
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

        # Extra parameters: one flattened projection matrix M_r (dim_e x dim_r)
        # per relation, mapping entity vectors into that relation's space.
        self.transfer_matrix = nn.Embedding(self.rel_tot, self.dim_e * self.dim_r)
        if not self.rand_init:
            # Start every M_r as the rectangular identity. Projection then
            # initially keeps the first min(dim_e, dim_r) coordinates
            # (zero-padding when dim_r > dim_e).
            identity = torch.eye(self.dim_e, self.dim_r).view(-1)
            self.transfer_matrix.weight.data.copy_(
                identity.expand(self.rel_tot, -1))
        else:
            nn.init.xavier_uniform_(self.transfer_matrix.weight.data)

        if margin is not None:
            # Kept as a frozen Parameter: it is part of the module's state but
            # receives no gradient.
            self.margin = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag = True
        else:
            self.margin_flag = False

    def _calc(self, h, t, r, mode):
        """Return the p-norm distance ``|| h + r - t ||`` as a flat tensor.

        In any mode other than ``'normal'``, ``h`` and ``t`` are regrouped to
        ``(-1, len(r), dim)`` and ``r`` to ``(1, len(r), dim)``. Each relation
        row is therefore broadcast across every group of candidates.
        """
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)
        if mode != 'normal':
            h = h.view(-1, r.shape[0], h.shape[-1])
            # Last dim must be the embedding size (shape[-1]), not the
            # batch size (shape[0]).
            t = t.view(-1, r.shape[0], t.shape[-1])
            r = r.view(-1, r.shape[0], r.shape[-1])
        if mode == 'head_batch':
            score = h + (r - t)
        else:
            score = (h + r) - t
        score = torch.norm(score, self.p_norm, -1).flatten()
        return score    # [k]

    def _transfer(self, e, r_transfer):
        """Project entity embeddings ``e`` into relation space.

        ``r_transfer`` holds flattened ``M_r`` rows and is viewed as
        ``(n_rel, dim_e, dim_r)``. If ``e`` has one row per relation, row i is
        multiplied by ``M_{r_i}``. Otherwise ``e`` is grouped as
        ``(-1, n_rel, dim_e)`` and position j of every group uses ``M_{r_j}``.
        Returns a ``(-1, dim_r)`` tensor.
        """
        r_transfer = r_transfer.view(-1, self.dim_e, self.dim_r)    # [n_rel, dim_e, dim_r]
        if e.shape[0] != r_transfer.shape[0]:
            # Put the relation axis first so matmul batches over relations.
            e = e.view(-1, r_transfer.shape[0], self.dim_e).permute(1, 0, 2)
            e = torch.matmul(e, r_transfer).permute(1, 0, 2)
        else:
            e = e.view(-1, 1, self.dim_e)    # [n, 1, dim_e]
            e = torch.matmul(e, r_transfer)    # [n, 1, dim_r]
        return e.view(-1, self.dim_r)

    def forward(self, data):
        """Score the triples in ``data``.

        ``data`` must provide ``'batch_h'``, ``'batch_t'``, ``'batch_r'``
        (index tensors) and ``'mode'``. Returns the distance, or
        ``margin - distance`` when a margin was configured.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']

        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        r_transfer = self.transfer_matrix(batch_r)
        # Map head and tail entities into the relation space via M_r.
        h = self._transfer(h, r_transfer)
        t = self._transfer(t, r_transfer)
        score = self._calc(h, t, r, mode)
        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def regularization(self, data):
        """Return the mean of the mean squared values of h, t, r and M_r for the batch."""
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']

        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        r_transfer = self.transfer_matrix(batch_r)
        regul = (torch.mean(h ** 2) +
                 torch.mean(t ** 2) +
                 torch.mean(r ** 2) +
                 torch.mean(r_transfer ** 2)
                ) / 4
        return regul

    def predict(self, data):
        """Return the raw distance ``|| h M_r + r - t M_r ||_p`` as a NumPy array."""
        score = self.forward(data)
        if self.margin_flag:
            # forward() returned margin - distance; undo it so predict always
            # reports the raw distance.
            score = self.margin - score
        return score.cpu().data.numpy()
# Training script.
# NOTE(review): TrainDataLoader, TestDataLoader, MarginLoss, NegativeSampling
# and Trainer are not imported in this snippet. They look like OpenKE classes;
# confirm and import them before running.

# The training data loader must exist before the model is built, because the
# entity/relation counts come from it.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/FB15K237_tiny/",
    nbatches=100,
    threads=8,
    # negative-sampling mode
    sampling_mode='normal',
    # use the "bern" strategy to build negative samples
    bern_flag=1,
    # negative samples generated by replacing entities
    neg_ent=25,
    # presumably disables relation-replacement negatives
    neg_rel=0,
)
# dataloader for test (not used further in this snippet)
test_dataloader = TestDataLoader("./benchmarks/FB15K237_tiny/", "link")

transr = TransR(train_dataloader.get_ent_tot(), train_dataloader.get_rel_tot())
loss = MarginLoss()
model = NegativeSampling(transr, loss, batch_size=train_dataloader.batch_size)
trainer = Trainer(model=model, data_loader=train_dataloader)
trainer.run()


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5720856.html

(0)
打赏 微信扫一扫 支付宝扫一扫
上一篇 2022-12-18
下一篇 2022-12-18

发表评论

登录后才能评论

评论列表(0条)

保存