论文地址:arxiv
代码地址:github
注意,代码第三方复现难免有diff,建议阅读论文~
以往的做法一般是把用户的兴趣转成一个单独的embedding,作者认为用户应该有多个兴趣,比如一个人喜欢鞋子又喜欢篮球,就应该有两个兴趣中心。
因此本文的主要动机就是对多兴趣进行建模。
本文的网络结构画风和DIN简直太像哈哈哈,网络结构也很简单,从输入往上看。
上图为输入的内容,other Feature主要为用户的属性特征,比如性别年龄等,item1-N为用户的序列行为,label那一侧应该是推荐的内容吧就是target item。
举个例子,背景是预测新推荐的短视频用户是否点击,那么item1-N就是用户以往点击的N个视频,target item就是当前推荐的短视频,other feature就是用户的年龄性别那些特征。
上图是本文的核心部分,针对图可以看到,N个item变成了5个兴趣中心,其实这部分的实现就是Kmeans哈,这个就是可以简单的理解成用聚类算法把N个item聚成了5个类,类中心就是兴趣中心。
但是为啥作者的实现不是Kmeans呢,其实也很简单,sklearn的Kmeans不是用tensor写的,没法传梯度然后反向传播更新模型,作者这里看着就像用tensor简单的实现了Kmeans的功能。
我这里简单的写了下流程:
c
j
h
c^h_j
cjh是输出的兴趣中心,
c
i
l
c^l_i
cil是输入的N个item,
S
i
j
S_{ij}
Sij是一个可学习的参数,流程如下:
1、首先 b i j b_{ij} bij用正态分布初始化
2、根据图示公式计算 w i j w_{ij} wij
3、根据图示公式计算 z j h z^h_j zjh
4、根据图示公式计算 c j h c^h_j cjh
5、根据图示公式更新 b i j b_{ij} bij
6、重复2-6执行n轮
公式真没啥就是矩阵乘的运算,可以简单的理解成Kmeans就好了。
然后这里需要人为的设定聚类的中心,其实是非常不友好的,这也是Kmeans算法比较致命的缺陷,优化的话可以试试DBSCAN这样的聚类算法,根据特征的密度自动确定类中心个数可能更合理一些。
后面一节就是把特征拼接起来过FC层接损失进行分类了,这还多了一个label-aware attention,实现也很简单,就是ReLU出来的特征过两个不同的FC生成K和V,然后K和Q(Q就是target item)点乘,算softmax后计算每个K的权重然后加权到V上了。
import torch
import torch.nn as nn
import torch.nn.functional as F
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1)
row_vector = row_vector.to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask.type(dtype)
return mask
def squash(inputs):
vec_squared_norm = torch.sum(torch.square(inputs), dim=1, keepdim=True)
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / torch.sqrt(vec_squared_norm + 1e-9)
vec_squashed = scalar_factor * inputs # element-wise
return vec_squashed
class RoutingNew(nn.Module):
def __init__(self, input_seq_len, input_dim, output_dim, iteration=1, output_seq_len=1):
super(RoutingNew, self).__init__()
self.iteration = iteration
self.output_seq_len = output_seq_len
self.input_seq_len = input_seq_len
self.input_dim = input_dim
self.output_dim = output_dim
self.B_matrix = nn.init.normal_(torch.empty(1, output_seq_len, input_seq_len), mean=0, std=1)
self.B_matrix.requires_grad = False
self.S_matrix = nn.init.normal_(torch.empty(self.input_dim, self.output_dim), mean=0, std=1)
def forward(self, low_capsule, seq_len):
## seq: B * 1
## low_capsule: B * in_seq_len * in_dim
## high_capsule: B * out_seq_len * out_dim
B, _, _ = low_capsule.size()
assert torch.max(seq_len).item() <= self.input_seq_len
assert self.iteration >= 1
seq_len_tile = seq_len.repeat(1, self.output_seq_len) # B * out_seq_len
for i in range(self.iteration):
mask = sequence_mask(seq_len_tile, self.input_seq_len) # B * out_seq_len * in_seq_len
pad = torch.ones_like(mask, dtype=torch.float32) * (-2 ** 16 + 1) # B * out_seq_len * in_seq_len
pad = pad.to(mask.device)
B_tile = self.B_matrix.repeat(B, 1, 1) # B * out_seq_len * in_seq_len
B_tile = B_tile.to(mask.device)
B_mask = torch.where(mask, B_tile, pad) # B * out_seq_len * in_seq_len
W = nn.functional.softmax(B_mask, dim=-1) # B * out_seq_len * in_seq_len
self.S_matrix = self.S_matrix.to(low_capsule.device)
low_capsule_new = torch.einsum('ijk,ko->ijo', (low_capsule, self.S_matrix)) # B * in_seq_len * output_dim
high_capsule_tmp = torch.matmul(W, low_capsule_new) # B * output_seq_len * output_dim
high_capsule = squash(high_capsule_tmp) # B * output_seq_len * output_dim
with torch.no_grad(): # 这里一定要截断梯度,不然反向传播这个for循环会报错
B_delta = torch.sum(
torch.matmul(high_capsule, torch.transpose(low_capsule_new, dim0=1, dim1=2)),
dim=0, keepdim=True)
B_delta = B_delta.to(mask.device)
self.B_matrix = self.B_matrix.to(mask.device)
self.B_matrix = B_delta + self.B_matrix
return high_capsule
if __name__ == '__main__':
# 序列长16,输入维度128,输出维度256,迭代1轮,输出的兴趣中心3个
model = RoutingNew(input_seq_len=16,input_dim=128,output_dim=256,iteration=1,output_seq_len=3)
input = torch.randn((64,16,128))
# 因为序列可能不定长,需要给padding的部分mask掉,这里测试就16个全保留
seq_mask = torch.ones((64,1))*16
output = model(input,seq_mask)
print(output.shape) # 64,3,256
效果
这个效果看不太懂,看起来是变好了,但是我在工业数据测试这个模型效果很差…而且类中心设几个合理呢?感觉不是很好用的一个方法。
但是呢多个兴趣中心的这个想法是很不错的,有机会想想怎么优化这个方案~太忙了。
。
哎
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)