推荐系统(十七):在TensorFlow中用户特征和商品特征是如何Embedding的?

在前面几篇关于推荐模型的文章中,笔者均给出了示例代码,有读者反馈——想知道在 TensorFlow 中用户特征和商品特征是如何 Embedding 的?因此,笔者特意写作此文加以解答。

1. 何为 Embedding ?

关于 Embedding,笔者很久之前写过一篇文章《推荐系统(十一):推荐系统中的 Embedding》,现在看来,差强人意,不过,对 Embedding 的概念解读还是不错的,只是缺乏代码案例解读。在本文中,笔者将基于 TensorFlow 来做解读,让读者加深理解。

如下图所示,为一个极简(CTR 和 CVR 共享了交互层) “双塔模型”(详见文章《推荐系统(十五):基于双塔模型的多目标商品召回/推荐系统》),简单解读一下:

  1. User Feature 和 Item Feature 先经过 Embedding Layer 处理,得到特征的 Embedding;
  2. User Feature Embedding 和 Item Feature Embedding 经过 Concat Layer 连接后输入到 DNN 网络;这样直接 Concat 得到的 Embedding 结果被称为 User 和 Item 的 “表示(Representation)”,显然,这种 “表示” 比较粗糙;
  3. 经过 MLP 处理,得到 User Vector 和 Item Vector,相较于上一步的 “表示形式”,User Vector 和 Item Vector 要 “精细” 得多,是真正意义上的 User Embedding 和 Item Embedding。
  4. User Embedding 和 Item Embedding 计算内积后经过 Sigmoid 函数处理(即图中的 Prediction),即可得到一个 0~1 之间的数值,即概率。
  5. 对于商品点击(1-点击,0-未点击)和商品转化(1-转化,0-未转化)这种二分类问题,结合模型预测的概率和样本 Label,很容易计算出损失(二分类问题一般采用交叉墒损失)。
  6. 对于 CTR 和 CVR 这种多任务场景,需要将 CTR Loss 和 CVR Loss 加权融合作为最终的损失,进而指导训练模型。
    在这里插入图片描述

2.特征工程中的 Embedding

2.1 ID 类特征

在 User Feature 和 Item Feature 中,User ID 和 Item ID 是最为重点的特征之一,是典型的 “高维稀疏” 特征。直接以原始数据形式输入模型是不行的,必须经过 Embedding Layer 的处理。在此,以 Item ID 为例,Embedding 处理的代码如下:

# 模拟生成商品特征,其中 item_id 取值[1, 10000]
num_items = 10000
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 基于 TensorFlow 对原始的 item_id 进行 Embedding 处理,分为两步
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)

1. 分类列的创建:categorical_column_with_identity

item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
  • 功能:将输入的整数 item_id 直接映射为分类标识。例如,若 num_items=1000,则输入的 item_id 必须是 [0,
    1, 2, …, 999] 范围内的整数。
  • 本质:这类似于对 item_id 做 One-Hot 编码(但底层实现更高效,不显式生成稀疏矩阵)。

2.嵌入列的创建:embedding_column

item_id_emb = feature_column.embedding_column(item_id, dimension=8)
  • 功能:将高维稀疏的分类 ID(如 num_items=1000 维的 One-Hot 向量)映射到低维稠密的连续向量空间(维度为 8)。
  • 关键点:嵌入矩阵的维度是 [num_items, 8],即每个 item_id 对应一个 8 维向量。这个嵌入矩阵是一个可训练参数,初始值随机(如 Glorot 初始化),通过神经网络的反向传播逐步优化。

3.嵌入向量的训练过程

  • 何时生成:嵌入矩阵的值并非预先计算,而是在模型训练时动态学习。
  • 如何学习
    1-输入数据中的 item_id 会触发嵌入层查找对应的 8 维向量。
    2-在反向传播时,优化器(如Adam)根据损失函数的梯度调整嵌入矩阵的值。
    3-模型通过最小化损失函数,迫使相似的 item_id 在嵌入空间中靠近,从而捕捉潜在语义关系(如用户行为中的物品相似性)。

4. 嵌入层的底层实现

当你在模型中调用 item_id_emb 时,TensorFlow 会隐式完成以下操作:

# 伪代码解释
embedding_matrix = tf.Variable(  # 可训练参数initial_value=tf.random.uniform([num_items, 8]), name="item_id_embedding"
)
# 根据输入的item_id查找嵌入向量
item_id_emb = tf.nn.embedding_lookup(embedding_matrix, input_item_ids)

5. 嵌入的优势

  • 降维:将高维稀疏特征压缩为低维稠密向量(例如从1000 维的 One-Hot 降到 8 维)。
  • 语义学习:模型自动学习嵌入空间中的几何关系(如相似物品的向量距离更近)。
  • 泛化性:即使某些 item_id 在训练数据中出现次数少,其嵌入向量仍可通过相似物品的梯度更新得到合理表示。

6. 完整流程示例

假设你的模型是一个推荐系统,处理流程如下:

  • 输入层:接收原始特征(如 {‘item_id’: 5})。
  • 特征转换:通过 item_id_emb 将 item_id=5 转换为一个 8 维向量。
  • 神经网络:将嵌入向量输入全连接层(如 DNN)、激活函数等后续结构。
  • 训练:通过损失函数(如点击率预测的交叉熵)反向传播,更新嵌入矩阵和其他权重。

2.2 类别特征

以用户性别为例:

# 模拟生成用户特征,其中用户性别是可以枚举的类别特征:male,female
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),  # 城市编码,中国有 2856 个城市'device_type': np.random.randint(0, 5, size=num_users)  # 设备类型(0=Android,1=iOS等)
}
# 对性别特征进行 Embedding 处理
user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)

1. 定义分类特征列

代码如下:

user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
  • 作用:将字符串类型的性别特征(如 ‘male’ 或 ‘female’)映射为整数索引。
  • 细节: 输入特征名为 ‘user_gender’,词汇表为 [‘male’, ‘female’]。 模型会根据词汇表将 ‘male’ 编码为
    0,‘female’ 编码为 1。 如果输入的值不在词汇表中(如 ‘unknown’),默认会被映射为 -1(可通过
    num_oov_buckets 参数调整)。

2. 创建嵌入列(Embedding Column)

如下代码:

user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)
  • 作用:将稀疏的整数索引转换为密集的低维向量(嵌入向量)。
  • 细节
    1.嵌入矩阵的维度:嵌入矩阵的形状为 (vocab_size, embedding_dimension),即 (2, 2)。 行数 2:对应词汇表中的两个类别(male 和 female);列数 2:指定的嵌入维度 dimension=2。
    2.嵌入初始化:嵌入向量的初始值默认通过随机均匀分布生成(可通过 initializer 参数自定义)。
    3.训练过程:嵌入向量会在模型训练时通过反向传播自动优化,学习与任务相关的语义表示。

2.3 数值特征

数值特征是一种简单的特征,按照常理,可以直接用原始数据进行模型训练和预测,然而,由于不同类型的数值特征存在 “量纲差异”,从而使得不同类型的数值特征 “不可比较”(如年龄数值区间(0~150),价格区间(0~10000000)),因此,数值特征也需要处理,比如标准化/归一化,好处如下:

  • 统一特征尺度,避免梯度下降因不同特征量纲而震荡。
  • 所有特征在相同尺度下,模型权重更新更均衡。
  • L1/L2正则化对所有特征施加相似强度的惩罚。

以用户年龄为例:

scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')

1. 数据标准化处理

代码如下:

scaler_age = StandardScaler()
  • 作用:创建一个标准化处理器,用于对数值型特征(如年龄)进行均值方差标准化(Z-Score标准化)。
  • 细节:StandardScaler 是 scikit-learn 库中的标准化工具,核心操作为:标准化值 =(原始值−均值)/ 标准差 标准化后,数据分布均值为 0,标准差为 1,消除量纲差异。适用于数值范围大、分布不均衡的特征(如年龄范围可能从 0 到 100)。

2.应用标准化到年龄列

代码如下:

df['user_age'] = scaler_age.fit_transform(df[['user_age']])
  • 作用:对 DataFrame 中的 user_age 列进行拟合和转换,实现标准化。
  • 细节
    1. fit_transform 两步合并。fit:计算 user_age 列的均值(μ)和标准差(σ)。transform:使用公式:(X−μ)/ σ,对所有样本进行标准化。
    2. 示例:假设原始年龄数据为 [20, 30, 40],均值为 30,标准差为 8.16,标准化后为 [-1.22, 0, 1.22]。
    3. 存储参数:scaler_age 对象会保存计算出的 μ 和 σ,便于后续对新数据(如测试集)使用 transform 而非重新拟合。

3.创建数值特征列

代码如下:

user_age = feature_column.numeric_column('user_age')
  • 作用:定义 TensorFlow 模型可接收的数值型特征列,将标准化后的年龄值直接输入模型。
  • 细节
    1. 输入数据类型:该列接收的是连续数值(如标准化后的 -1.22、0、1.22)。
    2. 模型中的处理:在训练时,每个样本的 user_age 值会以浮点数形式直接传递给神经网络,无需进一步编码。
    3. 参数扩展性: 可结合其他参数增强特征(例如 normalizer_fn 可添加自定义归一化,但此处已提前标准化,通常不再需要)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/73831.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++第三课(基础c)

1.前文 2.break 3.continue 4.return 0 1.前文 上次写文章到现在&#xff0c;有足足这么多天&#xff08;我也不知道&#xff0c;自己去数吧&#xff09; 开始吧 2.break break是结束循环的意思 举个栗子 #include<bits/stdc.h> using namespace std; int main(…

关于ArcGIS中加载影像数据,符号系统中渲染参数的解析

今天遇到一个很有意思的问题&#xff0c;故记录下来&#xff0c;以作参考和后续的研究。欢迎随时沟通交流。如果表达错误或误导&#xff0c;请各位指正。 正文 当我们拿到一幅成果影像数据的时候&#xff0c;在不同的GIS软件中会有不同效果呈现&#xff0c;但这其实是影像是…

北森测评的经验

测评经验记录 首先声明&#xff0c;北森测评就是垃圾&#xff0c;把行测拿过来就能评测能力了&#xff1f;直接去参加公务员考试更好。网上2024年的题库 评测分为 阅读理解数学计算图形题性格测试 图形题 总结的经验如下 图形推理题 一组图形&#xff0c;推测另一组图形最…

Java/Scala是什么

Java 和 Scala 是两种运行在 ​JVM&#xff08;Java 虚拟机&#xff09;​ 上的编程语言&#xff0c;虽然共享相同的运行时环境&#xff0c;但它们在设计哲学、语法特性和适用场景上有显著差异。以下是两者的详细解析&#xff1a; ​1. Java ​核心特性 ​面向对象&#xff1…

SQL Server 备份相关信息查看

目录标题 一、统计每个数据库在不同备份目录和备份类型下的备份次数&#xff0c;以及最后一次备份的时间整体功能详细解释 二、查询所有完整数据库备份的信息&#xff0c;包括备份集 ID、数据库名称、备份开始时间和备份文件的物理设备名称&#xff0c;并按备份开始时间降序排列…

CANoe入门——CANoe的诊断模块,调用CAPL进行uds诊断

目录 一、诊断窗口介绍 二、诊断数据库文件管理 三、添加基础诊断描述文件&#xff08;若没有CDD/ODX/PDX文件&#xff09;并使用对应的诊断功能进行UDS诊断 3.1、添加基础诊断描述文件 3.2、基于基础诊断&#xff0c;使用诊断控制台进行UDS诊断 3.2.1、生成基础诊断 3.…

【数据结构】二叉树的递归

数据结构系列三&#xff1a;二叉树(二) 一、递归的原理 1.全访问 2.主角 3.返回值 4.执等 二、递归的化关系思路 三、递归的方法设计 一、递归的原理 1.全访问 方法里调用方法自己&#xff0c;就会形成调用方法本身的一层一层全新相同的调用&#xff0c;方法的形参设置…

Imgui处理glfw的鼠标键盘的方法

在Imgui初始化时&#xff0c;会重新接手glfw的键盘鼠标事件。也就是遇到glfw的键盘鼠标事件时&#xff0c;imgui先会运行自己的处理过程&#xff0c;然后再去处理用户自己注册的glfw的键盘鼠标事件。 看imgui_impl_glfw.cpp源码的安装回调函数部分代码 void ImGui_ImplGlfw_In…

【LVS】负载均衡群集部署(DR模式)

部署前IP分配 DR服务器&#xff1a;192.168.166.101 vip&#xff1a;192.168.166.100 Web服务器1&#xff1a;192.168.166.104 vip&#xff1a;192.168.166.100 Web服务器2&#xff1a;192.168.166.107 vip&#xff1a;192.168.166.100 NFS服务器&#xff1a;192.168.166.108 …

C++Primer学习(14.1 基本概念)

当运算符作用于类类型的运算对象时&#xff0c;可以通过运算符重载重新定义该运算符的含义。明智地使用运算符重载能令我们的程序更易于编写和阅读。举个例子&#xff0c;因为在Sales_item类中定义了输入、输出和加法运算符&#xff0c;所以可以通过下述形式输出两个Sales_item…

计算机视觉准备八股中

一边记录一边看&#xff0c;这段实习跑路之前运行完3DGAN&#xff0c;弄完润了&#xff0c;现在开始记忆八股 1.CLIP模型的主要创新点&#xff1a; 图像和文本两种不同模态数据之间的深度融合、对比学习、自监督学习 2.等效步长是每一步操作步长的乘积 3.卷积层计算输入输出…

基于大语言模型的智能音乐创作系统——从推荐到生成

一、引言&#xff1a;当AI成为音乐创作伙伴 2023年&#xff0c;一款由大语言模型&#xff08;LLM&#xff09;生成的钢琴曲《量子交响曲》在Spotify冲上热搜&#xff0c;引发音乐界震动。传统音乐创作需要数年专业训练&#xff0c;而现代AI技术正在打破这一壁垒。本文提出一种…

Mysql---锁篇

1&#xff1a;MySQL 有哪些锁&#xff1f; 全局锁 flush tables with read lock 整个数据库就处于只读状态了 unlock tables 释放全局锁 全局锁主要应用于做全库逻辑备份&#xff0c;这样在备份数据库期间&#xff0c;不会因为数据或表结构的更新&#xff0c;而出现备份文件的数…

VLAN综合实验二

一.实验拓扑&#xff1a; 二.实验需求&#xff1a; 1.内网Ip地址使用172.16.0.0/分配 2.sw1和SW2之间互为备份 3.VRRP/STP/VLAN/Eth-trunk均使用 4.所有Pc均通过DHCP获取IP地址 5.ISP只能配置IP地址 6.所有…

GEO(生成引擎优化)实施策略全解析:从用户意图到效果追踪

——基于行业实证的AI信源占位方法论 ​一、理解用户查询&#xff1a;构建AI语料的核心起点 生成式AI的内容推荐逻辑以用户意图为核心&#xff0c;​精准捕捉高频问题是GEO优化的第一步。企业需通过以下方法挖掘用户真实需求&#xff1a; ​AI对话日志分析&#xff1a; 分析用…

HTML基础及进阶

目录 一、HTML基础 1.什么是HTML 2.常用标签 &#xff08;1&#xff09;标题标签&#xff1a;h1-h6数字越小文字会越大&#xff0c;这个标签会占一整行 &#xff08;2&#xff09;加粗标签&#xff1a; &#xff08;3&#xff09;换行标签&#xff1a; &#xff08;4&am…

MSTP与链路聚合技术

MSTP&#xff08;多生成树协议&#xff09; 简介 MSTP&#xff08;多生成树协议&#xff09;是Spanning Tree Protocol&#xff08;STP&#xff09;的改进版&#xff0c;支持网络中使用多条生成树&#xff0c;并根据用户需求限制生成树间的路径。MSTP将多个VLAN映射到一棵生成…

ModuleNotFoundError: No module named ‘ml_logger.logbook‘

问题 (legion) zhouy24RL-DSlab:~/zhouy24Files/legion/LEGION$ python main.py ML_LOGGER_USER is not set. This is required for online usage. Traceback (most recent call last): File “main.py”, line 7, in from mtrl.app.run import run File “/data/zhouy24File…

c# ftp上传下载 帮助类

工作中FTP的上传和下载还是很常用的。如下载打标数据,上传打标结果等。 这个类常用方法都有了:上传,下载,判断文件夹是否存在,创建文件夹,获取当前目录下文件列表(不包括文件夹) ,获取当前目录下文件列表(不包括文件夹) ,获取FTP文件列表(包括文件夹), 获取当前目…

PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练&#xff08;Distributed Data Parallel, DDP&#xff09; 一、DDP 核心概念 torch.nn.parallel.DistributedDataParallel 1. DDP 是什么&#xff1f; Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口&#xff0c;DistributedDataPara…