推荐系统(十八):优势特征蒸馏(Privileged Features Distillation)在商品推荐中的应用

在商品推荐系统中,粗排和精排环节的知识蒸馏方法主要通过复杂模型(Teacher)指导简单模型(Student)的训练,以提升粗排效果及与精排的一致性。本文将以淘宝的一篇论文《Privileged Features Distillation at Taobao Recommendations》中介绍的 PFD(Privileged Features Distillation)方法为例实现一个Demo,帮助读者学习知识蒸馏。

1.知识蒸馏方法概述

知识蒸馏诞生至今,早已不局限于粗排,而是在粗排和精排均有应用。粗排和精排的知识蒸馏核心在于通过不同形式的知识迁移(logits、排序结果、特征)提升模型效果与一致性。粗排侧重从精排获取排序偏好,而精排侧重模型压缩。实际应用中需结合业务场景选择蒸馏策略,并权衡性能与效果。本节简要介绍一下知识蒸馏方法。

一、粗排环节的典型蒸馏方法

粗排需平衡性能和效果,通常以精排为Teacher进行知识迁移,主要方法包括:

(1)Logits蒸馏

  • 原理:利用精排模型的输出logits(未归一化的预测值)作为软标签(soft label),指导粗排模型学习。通过引入温度系数(Temperature Scaling)调整软标签的分布,增强非主导类别的信息传递。
  • 损失函数:粗排模型的损失由两部分组成: Hard Loss:基于真实标签的交叉熵损失; Soft Loss:基于精排输出logits的KL散度或MSE损失。
  • 应用:美团、爱奇艺等采用两阶段训练,先训练精排Teacher,再固定其参数指导粗排Student56。

(2)排序结果蒸馏

  • 原理:直接利用精排输出的有序列表信息,构造粗排的训练样本。常见方法包括:

     1. Point-wise:将精排Top-K结果作为正样本,其余作为负样本,并引入位置权重。2. Pair-wise:从精排列表中随机抽取商品对,学习偏序关系(如BPR损失)。、3. List-wise:通过NDCG等指标对齐粗排与精排的整体排序。 
    
  • 优势:缓解样本选择偏差,增强粗排对精排排序偏好的拟合。

(3)特征蒸馏

  • 原理:迁移精排模型的中间层特征,要求粗排和精排的网络结构部分对齐。例如: 隐层特征对齐:通过MSE损失约束粗排与精排的隐层输出(如淘宝的 PFD(Privileged Features Distillation) 方法)。
  • 优势特征蒸馏:将精排使用的交叉特征等“特权特征”迁移到粗排(如用户与商品的交互特征)。
  • 应用:淘宝在 KDD 2020 提出的 PFD 方法中,精排 Teacher 使用交叉特征,粗排 Student 仅用基础特征,通过蒸馏提升效果。

二、精排环节的典型蒸馏方法

精排蒸馏主要用于模型压缩,将复杂模型(如集成模型)的能力迁移至轻量级模型:

(1)Logits蒸馏

  • 原理:与粗排类似,使用复杂精排模型的 logits 指导轻量级 Student 模型训练。例如: 阿里 Rocket Launching 框架:Teacher 和 Student 共享 Embedding 层,联合训练并通过 logits 对齐。
  • 改进:爱奇艺双 DNN 模型进一步约束 Student 隐层与 Teacher 隐层的激活值相似性。

(2)多目标蒸馏

  • 原理:将精排的多任务输出(如CTR、CVR)迁移至 Student。例如: 腾讯在 SIGIR 2021 提出通过 KL 散度对齐多任务 logits,提升粗排/召回模型的多目标一致性。
  • 损失设计:结合多任务损失和蒸馏损失,如加权交叉熵或对比学习损失。

三、关键技术与实践

(1)温度系数(Temperature)

调节 softmax 输出的平滑度,温度值越大,分布越平滑,帮助 Student 学习 Teacher 的暗知识(Dark Knowledge)。

(2)两阶段训练 vs 联合训练

  • 两阶段:先独立训练 Teacher,再固定其参数指导 Student(稳定性高)。
  • 联合训练:Teacher 和 Student 同步更新(减少耗时,但需设计梯度阻断防止相互干扰)。

(3)实际应用案例

  • 美团:通过对比学习强化粗排与精排的特征对齐,粗排CTR提升 0.15%。
  • 淘宝:优势特征蒸馏使粗排 CTR 提升 5%,精排CVR提升 2.3%。
  • 腾讯音乐:多目标蒸馏在粗排阶段实现阅读时长与点击率的联合优化。

2. PFD(Privileged Features Distillation)方法介绍

PFD(Privileged Features Distillation)方法出自论文《Privileged Features Distillation at Taobao Recommendations》。论文中描述:在离线环境下同时训练两个模型:一个学生模型以及一个教师模型。其中学生模型和原始模型完全相同,而教师模型额外利用了优势特征, 其准确率也因此更高。通过将教师模型蒸馏出的知识(Knowlege, 本文特指教师模型中最后一层的输出)传递给学生模型,可以辅助其训练以进一步提升准确率。在线上服务时,我们只抽取学生模型进行部署,因为输入不依赖于优势特征,离线、在线的一致性得以保证。在 PFD 中,所有的优势特征都被统一到教师模型作为输入,加入更多的优势特征往往能带来模型更高的准确度。

PFD 不同于常见的模型蒸馏(Model Disitillation, 简称 MD)。 在 MD 中,教师模型和学生模型处理同样的输入特征,其中教师模型会比学生模型更为复杂, 比如,教师模型会用更深的网络结构来指导使用浅层网络的学生模型进行学习。在 PFD 中,教师和学生模型会使用相同网络结构,而处理不同的输入特征。MD 和 PFD 两者的差异如下图所示。
在这里插入图片描述

如上图所示:模型蒸馏(Model Distill, 简称 MD)与优势特征蒸馏(PFD)对比; 在 MD 中,知识(Knowledge)是从更复杂的模型中蒸馏出来,而在 PFD 中,知识是从优势特征中蒸馏出来。

由此可见,我们可以训练一个使用了复杂特征(如交叉特征)的模型作为老师,指导训练一个仅使用简单特征的学生模型,从而实现提升模型效果,而又不增加线上耗时(线上使用交叉特征等复杂特征通常会导致耗时大幅增加,因此,在粗排环节几乎不直接使用交叉特征)。

3.基于 Wide&Deep 指导训练 TowTower 模型

基于 PFD 方法的原理,在本节我们将实现一个知识蒸馏的 Demo。其中,Teacher 模型基于 Wide&Deep 模型;Student 模型则采用简单的“双塔模型”。为了简单起见,Wide&Deep 模型和 “双塔模型” 均为单目标(CTR )模型。

3.1 模拟数据构造

"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)

3.2 特征工程

代码如下,相较于 Student 模型,作为 Teacher 的 Wide&Deep 模型采用了更多的特征,特别是交叉特征。

"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)""" 
交叉特征预处理 
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)""" 
特征列定义 
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]

3.3 模型架构设计

Teacher 模型:采用 Wide&Deep 模型(模拟精排模型);Student 模型:采用普通 “双塔模型”(模拟粗排模型)。

"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}

3.4 知识蒸馏实现

本质上就是用 Teacher 模型指导 Student 模型训练。使得 Student 模型的预测结果逼近 Teacher 模型。

"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False)  # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature,  # 教师输出软化student_ctr / self.temperature  # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}

3.5 模型训练与评估

  • 第一步:数据准备;
  • 第二步:模型初始化;
  • 第三步:编译、训练 Teacher 模型;
  • 第四步:编译、训练 Student 模型;
  • 第五步:评估、可视化效果
"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3]  # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

3.6 模型服务化与测试

保存训练好的学生模型,在另一个工程中可以加载这个模型,并执行预测。

"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")

3.7 知识蒸馏完整代码

完整代码如下:

import tensorflow as tftf.config.set_visible_devices([], 'GPU')  # 禁用GPU设备
from tensorflow import feature_column
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)""" 
交叉特征预处理 
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)""" 
特征列定义 
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False)  # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature,  # 教师输出软化student_ctr / self.temperature  # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3]  # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")

3.8 运行效果

Teacher 模型训练过程:

训练教师模型...
Epoch 1/5
2025-03-30 21:41:55.398982: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
25/25 [==============================] - 2s 13ms/step - loss: 0.5838 - accuracy: 0.5013 - val_loss: 0.5115 - val_accuracy: 0.4850
Epoch 2/5
25/25 [==============================] - 0s 3ms/step - loss: 0.5049 - accuracy: 0.5013 - val_loss: 0.5101 - val_accuracy: 0.4850
Epoch 3/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5037 - accuracy: 0.5013 - val_loss: 0.5093 - val_accuracy: 0.4850
Epoch 4/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5026 - accuracy: 0.5013 - val_loss: 0.5085 - val_accuracy: 0.4850
Epoch 5/5
25/25 [==============================] - 0s 5ms/step - loss: 0.5014 - accuracy: 0.5013 - val_loss: 0.5077 - val_accuracy: 0.4850

Student 模型训练过程:

训练学生模型...
Epoch 1/10
25/25 [==============================] - 2s 11ms/step - accuracy: 0.4975 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 2/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5038 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 3/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5063 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 4/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5050 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 5/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 6/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 7/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 8/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 9/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 10/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5088 - val_loss: 0.0000e+00 - val_accuracy: 0.4900

模型结构及预测示例:

Model: "tow_tower_student"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense_features_2 (DenseFeat  multiple                 23706     ures)                                                           dense_features_3 (DenseFeat  multiple                 1620      ures)                                                           sequential_1 (Sequential)   (None, 32)                4000      sequential_2 (Sequential)   (None, 32)                2976      dense_8 (Dense)             multiple                  2         =================================================================
Total params: 32,304
Trainable params: 32,304
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 178ms/step
预测结果:CTR=0.526

可视化训练过程:
在这里插入图片描述

3.9 模型预测

在另一个工程中加载通过蒸馏训练好的 Student 模型,并执行预测,代码示例如下:

# 导入必要的库
import tensorflow as tf
import pandas as pd
import numpy as np# 人工构造数据
num_users = 100
num_items = 200# 重新生成新的样本,模拟真实数据进行预测
def generate_new_samples(num_samples=5):new_samples = []for _ in range(num_samples):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)user_age = np.random.randint(18, 65)user_gender = np.random.choice(['male', 'female'])user_occupation = np.random.choice(['student', 'worker', 'teacher'])city_code = np.random.randint(1, 2856)device_type = np.random.randint(0, 5)item_category = np.random.choice(['electronics', 'books', 'clothing'])item_brand = np.random.choice(['brandA', 'brandB', 'brandC'])item_price = np.random.randint(1, 199)new_samples.append({'user_id': user_id,'user_age': user_age,'user_gender': user_gender,'user_occupation': user_occupation,'city_code': city_code,'device_type': device_type,'item_id': item_id,'item_category': item_category,'item_brand': item_brand,'item_price': item_price})return pd.DataFrame(new_samples)# 生成并打印预览新的样本数据
new_samples = generate_new_samples(num_samples=5)
# 设置display.max_columns为None,强制显示全部列:
pd.set_option('display.max_columns', None)
print("\nGenerated New Samples:\n", new_samples)# 准备输入数据
input_dict = {'user_id': tf.convert_to_tensor(new_samples['user_id'].values, dtype=tf.int64),'user_age': tf.convert_to_tensor(new_samples['user_age'].values, dtype=tf.int64),'user_gender': tf.convert_to_tensor(new_samples['user_gender'].values, dtype=tf.string),'user_occupation': tf.convert_to_tensor(new_samples['user_occupation'].values, dtype=tf.string),'city_code': tf.convert_to_tensor(new_samples['city_code'].values, dtype=tf.int64),'device_type': tf.convert_to_tensor(new_samples['device_type'].values, dtype=tf.int64),'item_id': tf.convert_to_tensor(new_samples['item_id'].values, dtype=tf.int64),'item_category': tf.convert_to_tensor(new_samples['item_category'].values, dtype=tf.string),'item_brand': tf.convert_to_tensor(new_samples['item_brand'].values, dtype=tf.string),'item_price': tf.convert_to_tensor(new_samples['item_price'].values, dtype=tf.int64)
}# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')
# 明确使用默认签名
predict_fn = loaded_model.signatures['serving_default']
predictions = predict_fn(**input_dict)# 提取并打印预测结果
# 预测结果是一个 CTCVR 综合分
predicted_ctr = predictions['ctr_logits'].numpy().flatten()
new_samples['ctr_prob'] = predicted_ctr
print("\nPrediction Results:")
for idx, row in new_samples.iterrows():print(f"Item ID: {row['item_id']} | CTR Final Score: {row['ctr_prob']:.4f}")

运行结果如下:

Generated New Samples:user_id  user_age user_gender user_occupation  city_code  device_type  \
0       34        49      female         teacher        843            0   
1       15        30      female         student        564            3   
2       26        37        male         teacher       2229            0   
3       31        35        male          worker       2494            0   
4       41        57      female         student       1668            3   item_id item_category item_brand  item_price  
0      147   electronics     brandA         127  
1      196      clothing     brandC         190  
2        1         books     brandA           1  
3      150      clothing     brandA           5  
4      128   electronics     brandA         156  
Metal device set to: Apple M1 ProPrediction Results:
Item ID: 147 | CTR Final Score: 0.5263
Item ID: 196 | CTR Final Score: 0.5263
Item ID: 1 | CTR Final Score: 0.5263
Item ID: 150 | CTR Final Score: 0.4793
Item ID: 128 | CTR Final Score: 0.5263

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/75100.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习四大核心架构:神经网络(NN)、卷积神经网络(CNN)、循环神经网络(RNN)与Transformer全概述

目录 📂 深度学习四大核心架构 🌰 知识点概述 🧠 核心区别对比表 ⚡ 生活化案例理解 🔑 选型指南 📂 深度学习四大核心架构 第一篇: 神经网络基础(NN) 🌰 知识点概述…

R语言对偏态换数据进行转换(对数、平方根、立方根)

我们进行研究的时候经常会遇见偏态数据,数据转换是统计分析和数据预处理中的一项基本技术。使用 R 时,了解如何正确转换数据有助于满足统计假设、标准化分布并提高分析的准确性。在 R 中实现和可视化最常见的数据转换:对数、平方根和立方根转…

第十四届蓝桥杯省赛电子类单片机学习记录(客观题)

01.一个8位的DAC转换器,供电电压为3.3V,参考电压2.4V,其ILSB产生的输出电压增量是(D)V。 A. 0.0129 B. 0.0047 C. 0.0064 D. 0.0094 解析: ILSB(最低有效位)的电压增量计算公式…

HarmonyOSNext_API16_媒体查询

媒体查询条件详解 媒体查询是响应式设计的核心工具,通过判断设备特征动态调整界面样式。其完整规则由媒体类型、逻辑操作符和媒体特征三部分组成,具体解析如下: 一、媒体查询语法结构 基本格式: [媒体类型] [逻辑操作符] (媒体特…

Python+拉普拉斯变换求解微分方程

引言 在数学和工程学中,微分方程广泛应用于描述动态系统的行为,如电路、电气控制系统、机械振动等。求解微分方程的一个常见方法是使用拉普拉斯变换,尤其是在涉及到初始条件时。今天,我们将通过 Python 演示如何使用拉普拉斯变换来求解微分方程,并帮助大家更好地理解这一…

【算法】手撕快速排序

快速排序的思想 任取一个元素作为枢轴,然后想办法把这个区间划分为两部分,小于等于枢轴的放左边,大于等于枢轴的放右边 然后递归处理左右区间,直到空或只剩一个 具体动画演示详见 数据结构合集 - 快速排序(算法过程, 效率分析…

《八大排序算法》

相关概念 排序:使一串记录,按照其中某个或某些关键字的大小,递增或递减的排列起来。稳定性:它描述了在排序过程中,相等元素的相对顺序是否保持不变。假设在待排序的序列中,有两个元素a和b,它们…

深度学习篇---paddleocr正则化提取

文章目录 前言一、代码总述&介绍1.1导入必要的库1.1.1cv21.1.2re1.1.3paddleocr 1.2初始化PaddleOCR1.3打开摄像头1.4使用 PaddleOCR 进行识别1.5定义正则表达式模式1.6打印提取结果1.7异常处理 二、正则表达式2.1简介2.2常用正则表达式模式及原理2.2.1. 快递单号模式2.2.2…

JavaScript DOM与元素操作

目录 DOM 树、DOM 对象、元素操作 一、DOM 树与 DOM 对象 二、获取 DOM 元素 1. 基础方法 2. 现代方法(ES6) 三、修改元素内容 四、修改元素常见属性 1. 标准属性 2. 通用方法 五、通过 style 修改样式 六、通过类名修改样式 1. className 属…

单元测试的编写

Python 单元测试示例 在 Python 中,通常使用 unittest 模块来编写单元测试。以下是一个简单的示例: 示例代码:calculator.py # calculator.py def add(a, b):return a bdef subtract(a, b):return a - b 单元测试代码:test_c…

大模型学习:从零到一实现一个BERT微调

目录 一、准备阶段 1.导入模块 2.指定使用的是GPU还是CPU 3.加载数据集 二、对数据添加词元和分词 1.根据BERT的预训练,我们要将一个句子的句头添加[CLS]句尾添加[SEP] 2.激活BERT词元分析器 3.填充句子为固定长度 代码解释: 三、数据处理 1.…

10组时尚复古美学自然冷色调肖像电影照片调色Lightroom预设 De La Mer – Nautical Lightroom Presets

De La Mer 预设系列包含 10 种真实的调色预设,适用于肖像、时尚和美术。为您的肖像摄影带来电影美学和个性! De La Mer 预设非常适合专业人士和业余爱好者,可在桌面或移动设备上使用,为您的摄影项目提供轻松的工作流程。这套包括…

SDL多窗口多线程渲染技术解析

SDL多窗口多线程渲染技术解析 技术原理 SDL多线程模型与窗口管理 SDL通过SDL_Thread结构体实现跨平台线程管理。在多窗口场景中,每个窗口需关联独立的渲染器,且建议遵循以下原则: 窗口与渲染器绑定:每个窗口创建时生成专属渲染器(SDL_CreateRenderer),避免跨线程操作…

QT 跨平台发布指南

一、Windows 平台发布 1. 使用 windeployqt 工具 windeployqt --release --no-compiler-runtime your_app.exe 2. 需要包含的文件 应用程序 .exe 文件 Qt5Core.dll, Qt5Gui.dll, Qt5Widgets.dll 等 Qt 库 platforms/qwindows.dll 插件 styles/qwindowsvistastyle.dll (如果使…

L2-037 包装机 (分数25)(详解)

题目链接——L2-037 包装机 问题分析 这个题目就是模拟了物品在传送带和筐之间的传送过程。传送带用队列模拟,筐用栈模拟。 输入 3 4 4 GPLT PATA OMSA 3 2 3 0 1 2 0 2 2 0 -1输出 根据上述操作,输出的物品顺序是: MATA样例分析 初始…

机器学习的一百个概念(4)下采样

前言 本文隶属于专栏《机器学习的一百个概念》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索&…

qt6下配置qopengl

qt部件选择 Qt 6:需要手动选择 Qt Shader Tools 和 Qt 5 Compatibility Module(如果需要兼容旧代码) cmake文件 cmake_minimum_required(VERSION 3.16) # Qt6 推荐最低 CMake 3.16 project(myself VERSION 0.1 LANGUAGES CXX)set(CMAKE_A…

数据安全系列4:密码技术的应用-接口调用的身份识别

传送门 数据安全系列1:开篇 数据安全系列2:单向散列函数概念 数据安全系列3:密码技术概述 什么是认证? 一谈到认证,多数人的反应可能就是"用户认证" 。就是应用系统如何识别用户的身份,直接…

STL之map和set

1. 关联式容器 vector、list、deque、 forward_list(C11)等,这些容器统称为序列式容器,因为其底层为线性序列的数据结构,里面存储的是元素本身。 关联式容器也是用来存储数据的,与序列式容器不同的是,其里面存储的是结…

Vue3 其它API Teleport 传送门

Vue3 其它API Teleport 传送门 在定义一个模态框时,父组件的filter属性会影响子组件的position属性,导致模态框定位错误使用Teleport解决这个问题把模态框代码传送到body标签下