问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

YouTube双塔模型——基于流数据的采样偏差修正

创作时间:
作者:
@小白创作中心

YouTube双塔模型——基于流数据的采样偏差修正

引用
CSDN
1.
https://blog.csdn.net/Shallowmm/article/details/137211297

一、概述(出发点)

双塔模型通过采样负样本来优化损失函数,但是这种方式会受到采样偏差影响,影响模型性能,特别是在样本分布极度倾斜的情况下。YouTube中的视频数据是流数据,新增的Item无法包含在固定的语料库,因此需要在batch中进行负采样并计算in-batch softmax。从流数据中估计item的采样概率,并应用到采样偏差的修正上是改善模型性能的关键。

二、具体贡献

2.1 针对batch negative sampling的bias修正

在双塔模型中模型的输出用user embeding和item embeding的点乘来表示,用于衡量二者的相似度。双塔模型的训练方式分为三种:point-wise、piar-wise和list-wise。YouTube的双塔模型采用的是list-wise方式来进行训练。list-wise的训练可以看做是一个经典的多分类问题,我们通常采用softmax计算概率,然后使用交叉熵作为损失函数,优化目标是让正样本的分数尽可能高。

作者采用了加权对数似然作为损失函数,这里的$r_i$表示$(x_i, y_i)$的奖励,每一个label都是同等重要的在分类任务中,$r_i=1$表示正样例,$r_i=0$表示负样例;在推荐系统中,$r_i$的含义则可以进一步拓展,例如若用户在一个视频上观看的时间较长,则可以设置得较大,表示用户更喜欢这些视频。

由于YouTube这类产品面对的数据量非常庞大,因此不可能针对整个语料库中的item进行softmax,然后再计算交叉熵损失。因此一般会采样出一个样本子集,在这个样本子集上计算softmax(sampled softmax),但是不同于MLP模型负样本从固定的语料库中采样,面对实际业务中的流数据,负采样只能在batch中进行(batch negative sampling)。

但是这种采样方式会造成较大的bias,由于batch negative sampling隐式地使用了基于item出现频率的采样分布,因此对于热门的item,它被采样的概率更大,会被更多地作为负样本,从而热门样本会被过度惩罚。引用sampled softmax的做法,作者对user embeding 和item embeding的内积进行修正:

$p_j$表示batch中item j的采样概率,而$p_j$的估计则是后续工作的重点,在引入修正项之后,模型的训练就可以通过SGD来进行优化。

Tricks

作者在内积计算的部分还采用了两个tricks:

  • Normalization:对user侧和item侧的输出进行L2标准化
  • Temperature:引入温度参数,对输出进行平滑处理

如何理解温度参数$\tau$?

假设向量$s=[1,2,3]$,的情形就是直接进行softmax,得到的概率分布为[0.09,0.24,0.67],逐渐提高的值,可以发现概率分布越来越平滑,反之则越陡峭。文章后面也针对不同温度系数进行了实验。

一般来讲,如果温度系数设的越大,logits分布变得越平滑,那么log-softmax损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是它控制了模型对负样本的区分度。

$\tau$
$s_0$
$s_1$
$s_2$
sum
1
0.090031
0.244728
0.665241
1
2
0.186324
0.307196
0.50648
1
4
0.254275
0.326496
0.419229
1
0.5
0.015876
0.11731
0.866813
1

2.2 in-batch item采样概率的估计

该部分是论文的重点,作者提出的Streaming Frequency Estimation算法通过计算item连续两次被采样的平均间隔,得到采样概率的估计$p_j = 1 / B[h(j)]$

算法使用了两个数组A、B和一个哈希函数h,数组B记录当前平均采样间隔,数组A记录上次一采样的时间,利用A来辅助更新B

注意到哈希函数的输出空间大小为H,当H<M时会存在hash冲突,从而导致item采样概率的过度估计(因为A[h(y)]更新得很频繁,t-A[h(y)]也就偏小),论文的改进方法是使用multi-hash,即使用多组A、B和哈希函数h,最终的计算结果取最大的B[h(y)]

三、实验结果

作者首先测试了不同学习率和multi-hash对于概率估计算法的影响,并在Wikipedia dataset和YouTube上验证了引入修正项后moxing的有效性

3.1 Simulation on Frequency Estimation

实验表明:

  1. 较高的学习率导致更快的收敛时间,但误差相对较高
  2. multi-hash可以有效降低误差,即使在相同数量的参数下


3.2 Wikipedia Page Retrieval

经过修正的模型Recall@K表现明显优于未经过修正的模型和mse-gramian

3.3 YouTube Experiment

相较于基线模型,经过修正的模型在YouTube视频推荐中的表现更好

除了离线实验,作者还进行了在线实验,并利用了reward来训练模型,从而真实反映用户对于视频的参与程度

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号