Gorse 协同过滤模型训练详解

news/2026/1/19 11:23:51/文章来源:https://www.cnblogs.com/wanber/p/19500774

目录

  1. 协同过滤概述
  2. 矩阵分解原理
  3. 两大算法对比
  4. BPR 算法详解
  5. ALS 算法详解
  6. 训练流程
  7. 超参数调优
  8. 模型评估
  9. 实战示例

协同过滤概述

什么是协同过滤?

协同过滤(Collaborative Filtering) 是推荐系统中最经典的算法,通过分析"用户-物品"交互历史来预测用户对未交互物品的偏好。

核心思想:

"物以类聚,人以群分"
- 相似的用户喜欢相似的物品
- 喜欢某物品的用户,也会喜欢类似物品

为什么需要协同过滤?

问题场景:
- 用户 A 看过电影 1, 2, 3
- 用户 B 看过电影 1, 2, 4
- 用户 A 和 B 兴趣相似(都喜欢 1, 2)
- 那么可以推荐:- 给 A 推荐电影 4(B 看过)- 给 B 推荐电影 3(A 看过)

Gorse 中的协同过滤

Gorse 使用矩阵分解(Matrix Factorization)实现协同过滤,支持两种算法:

算法 全称 适用场景 训练方式
BPR Bayesian Personalized Ranking 隐式反馈(点击、浏览) SGD(随机梯度下降)
ALS Alternating Least Squares 隐式反馈 + 评分预测 交替最小二乘

矩阵分解原理

用户-物品交互矩阵

假设有 3 个用户和 4 个物品:

          物品1  物品2  物品3  物品4
用户A      1      1      0      ?
用户B      1      0      1      ?
用户C      0      1      1      ?
  • 1 表示有交互(点击、购买等)
  • 0 表示无交互
  • ? 表示需要预测的

问题:这个矩阵非常稀疏(大部分是 0),如何填充 ?

矩阵分解的核心思想

稀疏的交互矩阵分解为两个低维稠密矩阵

用户-物品矩阵 (M×N) = 用户向量矩阵 (M×K) × 物品向量矩阵 (K×N)M: 用户数
N: 物品数
K: 隐向量维度(通常 10-100)

数学表示

R ≈ P × Q^TR[u,i] ≈ P[u] · Q[i]  (向量点积)
  • R: 用户-物品交互矩阵
  • P[u]: 用户 u 的隐向量(K 维)
  • Q[i]: 物品 i 的隐向量(K 维)

直观理解

假设 K=2(两个隐因子):

用户 A 的向量: [0.8, 0.3]  → 80% 喜欢"动作片",30% 喜欢"爱情片"
物品 1 的向量: [0.9, 0.1]  → 90% 是"动作片",10% 是"爱情片"预测评分: 0.8 × 0.9 + 0.3 × 0.1 = 0.72 + 0.03 = 0.75 (高分,推荐!)

为什么要降维?

原始矩阵 分解后
100万用户 × 10万物品 = 1000亿个元素 100万×16 + 16×10万 = 1760万个元素
99% 是 0(稀疏) 100% 有值(稠密)
无法泛化(过拟合) 能够泛化(学到模式)

两大算法对比

BPR vs ALS

维度 BPR ALS
全称 Bayesian Personalized Ranking Alternating Least Squares
优化目标 最大化正样本排名高于负样本 最小化预测误差
损失函数 Pairwise Loss(成对损失) Pointwise Loss(逐点损失)
训练方式 SGD(随机梯度下降) 交替最小二乘(封闭解)
适用场景 排序任务(Top-N 推荐) 评分预测 + 排序
训练速度 较慢(每次一个样本) 较快(并行优化)
内存占用 较小 较大
默认选择 ✅ Gorse 默认 可选

何时选择 BPR?

✅ 适合:
- 只有隐式反馈(点击、浏览、购买)
- 关注排序(Top-N 推荐)
- 数据量大,需要在线训练❌ 不适合:
- 需要评分预测(用 ALS)
- 数据量小(用简单的 ItemCF)

何时选择 ALS?

✅ 适合:
- 需要评分预测
- 数据量大,训练时间充足
- 可以并行化训练❌ 不适合:
- 纯排序任务(BPR 更好)
- 内存受限(ALS 需要更多内存)

BPR 算法详解

核心思想

BPR 不预测评分,而是学习排序关系

对于用户 u:
- 物品 i (有交互)应该排在 物品 j (无交互)之前
- p(i >_u j) = σ(P_u · (Q_i - Q_j))

损失函数

# 伪代码
for user in users:positive_item = random.choice(user.interacted_items)negative_item = random.choice(all_items - user.interacted_items)score_pos = dot(user_vector, item_vector[positive_item])score_neg = dot(user_vector, item_vector[negative_item])loss = -log(sigmoid(score_pos - score_neg))

数学表示

L = -Σ log(σ(r̂_ui - r̂_uj)) + λ(||P_u||^2 + ||Q_i||^2 + ||Q_j||^2)σ(x) = 1 / (1 + e^(-x))  (Sigmoid 函数)
λ: 正则化系数

BPR 训练步骤

让我们看 Gorse 中的实现:

func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, config *FitConfig) Score {// 1. 初始化用户和物品向量(随机初始化)bpr.Init(trainSet)// 2. 迭代训练for epoch := 1; epoch <= bpr.nEpochs; epoch++ {// 并行训练(每个 worker 处理一批样本)parallel.Parallel(trainSet.CountFeedback(), config.Jobs, func(workerId, _ int) error {// 2.1 随机选择一个用户(有交互历史的)userIndex := random.choice(users with feedback)// 2.2 随机选择一个正样本(用户交互过的物品)posIndex := random.choice(user's interacted items)// 2.3 随机选择一个负样本(用户未交互的物品)negIndex := random.choice(items - user's interacted items)// 2.4 计算预测分数差diff = predict(user, posIndex) - predict(user, negIndex)// diff = P_u · Q_i - P_u · Q_j// 2.5 计算损失loss = log(1 + exp(-diff))// 2.6 计算梯度grad = exp(-diff) / (1 + exp(-diff))// 2.7 更新参数(梯度下降)// ∂L/∂Q_i = -grad * P_u - reg * Q_i// ∂L/∂Q_j = grad * P_u - reg * Q_j  // ∂L/∂P_u = -grad * (Q_i - Q_j) - reg * P_uQ_i += lr * (grad * P_u - reg * Q_i)Q_j += lr * (-grad * P_u - reg * Q_j)P_u += lr * (grad * (Q_i - Q_j) - reg * P_u)return nil})// 2.8 定期评估(每 10 个 epoch)if epoch % config.Verbose == 0 {score = Evaluate(bpr, valSet, trainSet, topK, ...)log.Info("NDCG@10", score.NDCG)}}return score
}

BPR 代码解析

1. 初始化阶段

func (bpr *BPR) Init(trainSet dataset.CFSplit) {// 随机初始化用户向量(均值 0,标准差 0.001)bpr.UserFactor = randomNormalMatrix(numUsers, nFactors,      // 隐向量维度(默认 16)initMean,      // 均值(默认 0)initStdDev,    // 标准差(默认 0.001))// 随机初始化物品向量bpr.ItemFactor = randomNormalMatrix(numItems, nFactors, initMean, initStdDev,)// 标记哪些用户/物品有训练数据for userIndex, feedback := range trainSet.GetUserFeedback() {if len(feedback) > 0 {bpr.UserPredictable.Set(userIndex)}}
}

2. 核心训练循环

// 2.1 选择用户
var userIndex int32
for {userIndex = random.Int31n(trainSet.CountUsers())if len(trainSet.GetUserFeedback()[userIndex]) > 0 {break  // 确保用户有交互历史}
}// 2.2 选择正样本
ratingCount := len(trainSet.GetUserFeedback()[userIndex])
posIndex := trainSet.GetUserFeedback()[userIndex][random.Intn(ratingCount)]// 2.3 选择负样本(不在用户交互集中)
var negIndex int32
for {temp := random.Int31n(trainSet.CountItems())if !userFeedback[userIndex].Contains(temp) {negIndex = tempbreak}
}

3. 梯度计算和参数更新

// 预测分数差
diff := bpr.internalPredict(userIndex, posIndex) - bpr.internalPredict(userIndex, negIndex)
// diff = P_u · Q_i - P_u · Q_j// 损失
cost[workerId] += log1p(exp(-diff))// Sigmoid 梯度
grad := exp(-diff) / (1.0 + exp(-diff))// 备份当前参数
copy(userFactor[workerId], bpr.UserFactor[userIndex])
copy(positiveItemFactor[workerId], bpr.ItemFactor[posIndex])
copy(negativeItemFactor[workerId], bpr.ItemFactor[negIndex])// 更新正样本物品向量
// Q_i = Q_i + lr * (grad * P_u - reg * Q_i)
temp = grad * userFactor
temp = temp - reg * positiveItemFactor
bpr.ItemFactor[posIndex] += lr * temp// 更新负样本物品向量
// Q_j = Q_j + lr * (-grad * P_u - reg * Q_j)
temp = -grad * userFactor
temp = temp - reg * negativeItemFactor
bpr.ItemFactor[negIndex] += lr * temp// 更新用户向量
// P_u = P_u + lr * (grad * (Q_i - Q_j) - reg * P_u)
temp = positiveItemFactor - negativeItemFactor
temp = grad * temp
temp = temp - reg * userFactor
bpr.UserFactor[userIndex] += lr * temp

BPR 超参数

type BPR struct {nFactors   int      // 隐向量维度(默认 16)nEpochs    int      // 训练轮数(默认 100)lr         float32  // 学习率(默认 0.05)reg        float32  // 正则化系数(默认 0.01)initMean   float32  // 初始化均值(默认 0)initStdDev float32  // 初始化标准差(默认 0.001)
}

调优建议

参数 默认值 调优范围 说明
n_factors 16 8-128 越大越精确,但训练慢
n_epochs 100 50-500 根据收敛情况调整
lr 0.05 0.001-0.1 过大不收敛,过小收敛慢
reg 0.01 0.001-0.1 防止过拟合
init_std_dev 0.001 0.0001-0.01 影响训练稳定性

ALS 算法详解

核心思想

ALS 通过交替固定用户向量或物品向量,优化另一组向量:

迭代 1: 固定物品向量 Q,优化用户向量 P
迭代 2: 固定用户向量 P,优化物品向量 Q
迭代 3: 固定物品向量 Q,优化用户向量 P
...

每一步都有封闭解(Closed-form Solution),无需梯度下降!

损失函数

L = Σ c_ui (r_ui - P_u · Q_i)^2 + λ(||P||^2 + ||Q||^2)c_ui: 置信度(1 + α * r_ui)
α: 权重参数(默认 0.001)
λ: 正则化系数(默认 0.06)

ALS 优化公式

对于用户 u,固定 Q 后,P_u 的最优解:

P_u = (Q^T C^u Q + λI)^(-1) Q^T C^u r_uC^u: 对角矩阵,对角线是 c_ui
r_u: 用户 u 的评分向量

ALS 训练步骤

func (als *ALS) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, config *FitConfig) Score {// 1. 初始化als.Init(trainSet)for epoch := 1; epoch <= als.nEpochs; epoch++ {// ==================== 第一步:更新用户向量 ====================// 1.1 计算 S^q = Σ Q_i Q_i^T(物品向量的格拉姆矩阵)S_q := zeros(nFactors, nFactors)for itemIndex := 0; itemIndex < trainSet.CountItems(); itemIndex++ {if len(trainSet.GetItemFeedback()[itemIndex]) > 0 {for i := 0; i < nFactors; i++ {for j := 0; j < nFactors; j++ {S_q[i][j] += Q[itemIndex][i] * Q[itemIndex][j]}}}}// 1.2 并行更新每个用户parallel.Parallel(trainSet.CountUsers(), config.Jobs, func(workerId, userIndex int) {userFeedback := trainSet.GetUserFeedback()[userIndex]// 对每个隐因子 ffor f := 0; f < nFactors; f++ {// 计算用户向量的第 f 维a, b, c := 0.0, 0.0, 0.0for _, itemIndex := range userFeedback {// a = Σ (1 - (1-α) * residual) * Q[i][f]a += (1 - (1-weight)*residual[itemIndex]) * Q[itemIndex][f]// c = Σ (1-α) * Q[i][f]^2c += (1 - weight) * Q[itemIndex][f] * Q[itemIndex][f]}// b = α * Σ P_u[k] * S_q[k][f]  (k≠f)for k := 0; k < nFactors; k++ {if k != f {b += weight * P[userIndex][k] * S_q[k][f]}}// 更新P[userIndex][f] = (a - b) / (c + weight*S_q[f][f] + reg)}})// ==================== 第二步:更新物品向量 ====================// 2.1 计算 S^p = Σ P_u P_u^T(用户向量的格拉姆矩阵)S_p := zeros(nFactors, nFactors)for userIndex := 0; userIndex < trainSet.CountUsers(); userIndex++ {if len(trainSet.GetUserFeedback()[userIndex]) > 0 {for i := 0; i < nFactors; i++ {for j := 0; j < nFactors; j++ {S_p[i][j] += P[userIndex][i] * P[userIndex][j]}}}}// 2.2 并行更新每个物品(类似用户更新)parallel.Parallel(trainSet.CountItems(), config.Jobs, func(workerId, itemIndex int) {// ... 类似用户更新逻辑})// 2.3 评估if epoch % config.Verbose == 0 {score = Evaluate(als, valSet, trainSet, topK, ...)log.Info("NDCG@10", score.NDCG)}}return score
}

ALS 代码解析

格拉姆矩阵(Gram Matrix)

// S^q = Σ q_i q_i^T
// 这是一个 K×K 的对称矩阵
floats.MatZero(s)
for itemIndex := 0; itemIndex < trainSet.CountItems(); itemIndex++ {if len(trainSet.GetItemFeedback()[itemIndex]) > 0 {for i := 0; i < als.nFactors; i++ {for j := 0; j < als.nFactors; j++ {s[i][j] += als.ItemFactor[itemIndex][i] * als.ItemFactor[itemIndex][j]}}}
}

这是所有物品向量的协方差矩阵,用于快速计算用户向量更新。

逐维优化

for f := 0; f < als.nFactors; f++ {// 计算残差(去掉当前维度的影响)for _, i := range userFeedback {userRes[workerId][i] = userPredictions[workerId][i] - als.UserFactor[userIndex][f] * als.ItemFactor[i][f]}// 计算 a, b, ca, b, c := float32(0), float32(0), float32(0)for _, i := range userFeedback {a += (1 - (1-als.weight)*userRes[workerId][i]) * als.ItemFactor[i][f]c += (1 - als.weight) * als.ItemFactor[i][f] * als.ItemFactor[i][f]}for k := 0; k < als.nFactors; k++ {if k != f {b += als.weight * als.UserFactor[userIndex][k] * s[k][f]}}// 更新第 f 维als.UserFactor[userIndex][f] = (a - b) / (c + als.weight*s[f][f] + als.reg)// 更新预测(加回当前维度的影响)for _, i := range userFeedback {userPredictions[workerId][i] = userRes[workerId][i] + als.UserFactor[userIndex][f] * als.ItemFactor[i][f]}
}

ALS 超参数

type ALS struct {nFactors   int      // 隐向量维度(默认 16)nEpochs    int      // 训练轮数(默认 50)reg        float32  // 正则化系数(默认 0.06)initMean   float32  // 初始化均值(默认 0)initStdDev float32  // 初始化标准差(默认 0.1)weight     float32  // 负样本权重(默认 0.001)
}

调优建议

参数 默认值 调优范围 说明
n_factors 16 8-128 越大越精确
n_epochs 50 20-200 ALS 收敛较快
reg 0.06 0.001-0.1 防止过拟合
weight (α) 0.001 0.0001-0.01 负样本权重
init_std_dev 0.1 0.01-0.5 比 BPR 大

训练流程

Master 中的训练流程

func (m *Master) trainCollaborativeFiltering(trainSet, testSet dataset.CFSplit) error {// 1. 数据检查if trainSet.CountUsers() == 0 || trainSet.CountItems() == 0 || trainSet.CountFeedback() == 0 {return errors.New("No data found")}// 2. 检查数据是否变化if trainSet.CountFeedback() == m.collaborativeFilteringTrainSetSize {log.Info("collaborative filtering dataset not changed")return nil  // 数据未变化,跳过训练}// 3. 选择模型(BPR 或 ALS)m.collaborativeFilteringModelMutex.Lock()modelType := m.collaborativeFilteringMeta.Type  // "BPR" or "ALS"params := m.collaborativeFilteringMeta.Params// 如果超参数优化找到了更好的模型,使用新参数if m.collaborativeFilteringTarget.Score.NDCG > m.collaborativeFilteringMeta.Score.NDCG {modelType = m.collaborativeFilteringTarget.Typeparams = m.collaborativeFilteringTarget.Paramslog.Info("find better collaborative filtering model", zap.Any("score", m.collaborativeFilteringTarget.Score))}m.collaborativeFilteringModelMutex.Unlock()// 4. 创建模型并训练model := m.newCollaborativeFilteringModel(modelType, params)score := model.Fit(ctx, trainSet, testSet,cf.NewFitConfig().SetJobs(m.Config.Master.NumJobs).SetPatience(m.Config.Recommend.Collaborative.EarlyStopping.Patience))log.Info("fit collaborative filtering model completed",zap.Float32("NDCG@10", score.NDCG),zap.Float32("Recall@10", score.Recall),zap.Float32("Precision@10", score.Precision))// 5. 构建物品向量索引(HNSW)matrixFactorizationItems := logics.NewMatrixFactorizationItems(time.Now())parallel.For(trainSet.CountItems(), m.Config.Master.NumJobs, func(i int) {if itemId, ok := trainSet.GetItemDict().String(int32(i)); ok && model.IsItemPredictable(int32(i)) {matrixFactorizationItems.Add(itemId, model.GetItemFactor(int32(i)))}})// 6. 提取用户向量matrixFactorizationUsers := logics.NewMatrixFactorizationUsers()for i := 0; i < trainSet.CountUsers(); i++ {if userId, ok := trainSet.GetUserDict().String(int32(i)); ok && model.IsUserPredictable(int32(i)) {matrixFactorizationUsers.Add(userId, model.GetUserFactor(int32(i)))}}// 7. 上传模型到 Blob StoremodelId := time.Now().Unix()w, done, err := m.blobStore.Create(strconv.Itoa(modelId))err = matrixFactorizationItems.Marshal(w)   // 保存物品向量 + HNSW 索引err = matrixFactorizationUsers.Marshal(w)   // 保存用户向量w.Close()<-done// 8. 更新元数据m.collaborativeFilteringMeta.ID = modelIdm.collaborativeFilteringMeta.Type = modelTypem.collaborativeFilteringMeta.Params = paramsm.collaborativeFilteringMeta.Score = scoreerr = m.metaStore.Put(meta.COLLABORATIVE_FILTERING_MODEL, m.collaborativeFilteringMeta.ToJSON())// 9. 更新 Prometheus 指标CollaborativeFilteringNDCG10.Set(float64(score.NDCG))CollaborativeFilteringRecall10.Set(float64(score.Recall))CollaborativeFilteringPrecision10.Set(float64(score.Precision))return nil
}

训练流程图

开始训练协同过滤模型↓
检查数据(用户、物品、反馈)↓
数据未变化? → Yes → 跳过训练↓ No
选择模型类型(BPR or ALS)↓
检查是否有更好的超参数↓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━模型训练(Fit)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━↓
初始化用户/物品向量↓
┌──────────────────────────┐
│  迭代训练(100 轮)        │
│  ├─ 采样三元组 (u,i,j)    │
│  ├─ 计算损失              │
│  ├─ 更新梯度              │
│  └─ 定期评估 (NDCG)       │
└──────────────────────────┘↓
━━━━━━━━━━━━━━━━━━━━━━━━━━━━↓
构建 HNSW 索引(物品向量)↓
提取用户向量↓
上传到 Blob Store├─ matrixFactorizationItems (物品向量 + HNSW)└─ matrixFactorizationUsers (用户向量)↓
更新元数据(模型 ID、类型、参数、分数)↓
更新 Prometheus 指标↓
完成

超参数调优

自动超参数优化

Gorse 使用 Optuna 进行自动超参数搜索(贝叶斯优化):

func (m *Master) optimizeCollaborativeFiltering(trainSet, testSet dataset.CFSplit) error {// 1. 检查是否需要优化if m.Config.Recommend.Collaborative.OptimizeTrials <= 0 {return nil  // 未启用优化}// 2. 创建优化器study, err := goptuna.CreateStudy("optimizeCollaborativeFiltering",goptuna.StudyOptionDirection(goptuna.StudyDirectionMaximize),  // 最大化 NDCGgoptuna.StudyOptionSampler(tpe.NewSampler()),                  // TPE 采样器goptuna.StudyOptionLogger(log.NewOptunaLogger(log.Logger())))// 3. 优化目标函数objective := func(trial goptuna.Trial) (float64, error) {// 3.1 建议超参数modelType, _ := trial.SuggestCategorical("model_type", []string{"BPR", "ALS"})var params model.Paramsif modelType == "BPR" {params = cf.NewBPR(nil).SuggestParams(trial)// 自动建议:// - lr: [0.001, 0.1] (对数分布)// - reg: [0.001, 0.1] (对数分布)// - init_std_dev: [0.001, 0.1] (对数分布)} else {params = cf.NewALS(nil).SuggestParams(trial)// 自动建议:// - reg: [0.001, 0.1]// - weight: [0.0001, 0.01]// - init_std_dev: [0.001, 0.1]}// 3.2 训练模型m := m.newCollaborativeFilteringModel(modelType, params)score := m.Fit(ctx, trainSet, testSet, cf.NewFitConfig().SetJobs(m.Config.Master.NumJobs).SetPatience(5))  // 早停// 3.3 返回 NDCG 作为优化目标return float64(score.NDCG), nil}// 4. 运行优化(20 次试验)err = study.Optimize(objective, m.Config.Recommend.Collaborative.OptimizeTrials)// 5. 获取最佳参数bestParams := study.BestParams()bestScore := study.BestValue()// 6. 更新目标模型m.collaborativeFilteringModelMutex.Lock()m.collaborativeFilteringTarget.Type = modelTypem.collaborativeFilteringTarget.Params = bestParamsm.collaborativeFilteringTarget.Score.NDCG = float32(bestScore)m.collaborativeFilteringModelMutex.Unlock()log.Info("found best collaborative filtering model",zap.String("type", modelType),zap.Any("params", bestParams),zap.Float64("NDCG", bestScore))return nil
}

超参数搜索空间

BPR

func (bpr *BPR) SuggestParams(trial goptuna.Trial) model.Params {return model.Params{model.NFactors:   16,  // 固定为 16model.Lr:         trial.SuggestLogFloat("Lr", 0.001, 0.1),model.Reg:        trial.SuggestLogFloat("Reg", 0.001, 0.1),model.InitMean:   0,   // 固定为 0model.InitStdDev: trial.SuggestLogFloat("InitStdDev", 0.001, 0.1),}
}

ALS

func (als *ALS) SuggestParams(trial goptuna.Trial) model.Params {return model.Params{model.NFactors:   16,  // 固定为 16model.InitMean:   0,model.InitStdDev: trial.SuggestLogFloat("InitStdDev", 0.001, 0.1),model.Reg:        trial.SuggestLogFloat("Reg", 0.001, 0.1),model.Alpha:      trial.SuggestLogFloat("Alpha", 0.001, 0.1),  // weight}
}

配置超参数优化

[recommend.collaborative]
# 启用超参数优化(20 次试验)
optimize_trials = 20# 早停策略(5 轮不改进则停止)
[recommend.collaborative.early_stopping]
patience = 5# 训练周期(每 3 小时训练一次)
fit_period = "3h"

模型评估

评估指标

Gorse 使用三个指标评估协同过滤模型:

指标 全称 说明 取值范围
NDCG Normalized Discounted Cumulative Gain 考虑排序位置的归一化增益 [0, 1]
Precision Precision@K 前 K 个推荐中相关物品的比例 [0, 1]
Recall Recall@K 所有相关物品中被推荐的比例 [0, 1]

NDCG@K 计算

func NDCG(targetSet mapset.Set[int32], rankList []int32, k int) float32 {if len(rankList) > k {rankList = rankList[:k]}// DCG = Σ (rel_i / log2(i+1))dcg := float32(0)for i, itemIndex := range rankList {if targetSet.Contains(itemIndex) {dcg += 1.0 / math32.Log2(float32(i+2))  // i+2 because i starts from 0}}// IDCG = 理想情况下的 DCG(所有相关物品排在最前面)idcg := float32(0)for i := 0; i < min(targetSet.Cardinality(), k); i++ {idcg += 1.0 / math32.Log2(float32(i+2))}if idcg == 0 {return 0}// NDCG = DCG / IDCGreturn dcg / idcg
}

示例

假设用户真实喜欢的物品是 {1, 3, 5}
模型推荐 Top-5: [3, 2, 1, 4, 5]DCG = 1/log2(2) + 0 + 1/log2(4) + 0 + 1/log2(6)= 1/1 + 0 + 1/2 + 0 + 1/2.58= 1 + 0.5 + 0.39= 1.89IDCG = 1/log2(2) + 1/log2(3) + 1/log2(4)= 1 + 0.63 + 0.5= 2.13NDCG = 1.89 / 2.13 = 0.887

Precision@K 和 Recall@K

func Precision(targetSet mapset.Set[int32], rankList []int32, k int) float32 {if len(rankList) > k {rankList = rankList[:k]}hit := 0for _, itemIndex := range rankList {if targetSet.Contains(itemIndex) {hit++}}// Precision = 命中数 / Kreturn float32(hit) / float32(len(rankList))
}func Recall(targetSet mapset.Set[int32], rankList []int32, k int) float32 {if len(rankList) > k {rankList = rankList[:k]}hit := 0for _, itemIndex := range rankList {if targetSet.Contains(itemIndex) {hit++}}// Recall = 命中数 / 真实相关物品数return float32(hit) / float32(targetSet.Cardinality())
}

示例

用户真实喜欢: {1, 3, 5}  (3 个物品)
推荐 Top-5: [3, 2, 1, 4, 5]
命中: {1, 3, 5}  (3 个)Precision@5 = 3 / 5 = 0.6
Recall@5 = 3 / 3 = 1.0

评估流程

func Evaluate(model cf.Model, testSet, trainSet dataset.CFSplit, k, candidates, jobs int) []float32 {// 1. 为每个测试用户生成推荐scores := make([][]float32, 3)  // NDCG, Precision, Recallparallel.Parallel(testSet.CountUsers(), jobs, func(workerId, userIndex int) {// 1.1 获取测试集中的真实交互targetSet := testSet.GetUserFeedback()[userIndex]if len(targetSet) == 0 {return  // 跳过无测试数据的用户}// 1.2 生成推荐列表(从候选集中)rankList := make([]int32, 0, candidates)for itemIndex := 0; itemIndex < trainSet.CountItems(); itemIndex++ {// 跳过训练集中已交互的物品if !trainSet.GetUserFeedback()[userIndex].Contains(itemIndex) {score := model.internalPredict(userIndex, itemIndex)rankList = append(rankList, itemIndex, score)}}// 1.3 排序(按分数降序)sort.Slice(rankList, func(i, j int) bool {return rankList[i].score > rankList[j].score})// 1.4 计算指标scores[0] += NDCG(targetSet, rankList, k)scores[1] += Precision(targetSet, rankList, k)scores[2] += Recall(targetSet, rankList, k)})// 2. 平均numUsers := testSet.CountUsersWithFeedback()return []float32{scores[0] / float32(numUsers),scores[1] / float32(numUsers),scores[2] / float32(numUsers),}
}

实战示例

示例 1:手动训练 BPR 模型

package mainimport ("context""github.com/gorse-io/gorse/dataset""github.com/gorse-io/gorse/model""github.com/gorse-io/gorse/model/cf"
)func main() {// 1. 加载数据data := dataset.NewDataset(time.Now(), 1000, 5000)// 添加用户-物品交互...trainSet, testSet := data.SplitCF(0.2, 0)// 2. 创建 BPR 模型params := model.Params{model.NFactors:   16,model.Lr:         0.05,model.Reg:        0.01,model.NEpochs:    100,model.InitMean:   0,model.InitStdDev: 0.001,}bpr := cf.NewBPR(params)// 3. 训练config := cf.NewFitConfig().SetJobs(4).SetVerbose(10).SetPatience(5)score := bpr.Fit(context.Background(), trainSet, testSet, config)// 4. 查看结果fmt.Printf("NDCG@10: %.4f\n", score.NDCG)fmt.Printf("Precision@10: %.4f\n", score.Precision)fmt.Printf("Recall@10: %.4f\n", score.Recall)// 5. 预测userVector := bpr.GetUserFactor(0)  // 用户 0 的向量itemVector := bpr.GetItemFactor(10) // 物品 10 的向量score := floats.Dot(userVector, itemVector)fmt.Printf("预测评分: %.4f\n", score)
}

示例 2:使用配置文件训练

# config.toml[recommend.collaborative]
# 模型类型(BPR 或 ALS)
model = "BPR"# 训练周期
fit_period = "3h"# 超参数
[recommend.collaborative.params]
n_factors = 16
n_epochs = 100
lr = 0.05
reg = 0.01
init_mean = 0.0
init_std_dev = 0.001# 自动超参数优化(20 次试验)
[recommend.collaborative]
optimize_trials = 20# 早停策略
[recommend.collaborative.early_stopping]
patience = 5

示例 3:查看训练进度

# 启动 Master
gorse-master --config config.toml# 查看日志
tail -f /var/log/gorse/master.log# 输出示例:
# [INFO] fit bpr 10/100 NDCG@10=0.3521 Precision@10=0.2134 Recall@10=0.1892
# [INFO] fit bpr 20/100 NDCG@10=0.4123 Precision@10=0.2567 Recall@10=0.2341
# [INFO] fit bpr 30/100 NDCG@10=0.4567 Precision@10=0.2891 Recall@10=0.2678
# ...
# [INFO] fit bpr complete NDCG@10=0.5123 Precision@10=0.3456 Recall@10=0.3234

示例 4:Prometheus 监控

# 查看训练指标
curl http://master-host:8086/metrics | grep collaborative# 输出示例:
# gorse_master_collaborative_filtering_fit_seconds 125.5
# gorse_master_collaborative_filtering_ndcg_10 0.5123
# gorse_master_collaborative_filtering_precision_10 0.3456
# gorse_master_collaborative_filtering_recall_10 0.3234

示例 5:对比 BPR vs ALS

// 训练 BPR
bpr := cf.NewBPR(model.Params{model.NFactors: 16,model.Lr:       0.05,model.Reg:      0.01,model.NEpochs:  100,
})
bprScore := bpr.Fit(ctx, trainSet, testSet, config)// 训练 ALS
als := cf.NewALS(model.Params{model.NFactors: 16,model.Reg:      0.06,model.Alpha:    0.001,model.NEpochs:  50,
})
alsScore := als.Fit(ctx, trainSet, testSet, config)// 对比
fmt.Printf("BPR NDCG: %.4f, ALS NDCG: %.4f\n", bprScore.NDCG, alsScore.NDCG)
fmt.Printf("BPR Recall: %.4f, ALS Recall: %.4f\n", bprScore.Recall, alsScore.Recall)

总结:协同过滤训练要点

1. 核心概念

概念 说明
矩阵分解 将稀疏的用户-物品矩阵分解为两个低维稠密矩阵
隐向量 用户/物品的 K 维向量,捕捉潜在特征
BPR 通过成对比较学习排序(适合 Top-N 推荐)
ALS 交替优化用户和物品向量(收敛快)

2. 训练流程

加载数据↓
划分训练集/测试集↓
初始化用户/物品向量(随机)↓
迭代训练(100 轮)├─ BPR: 采样三元组 (u,i,j),SGD 更新└─ ALS: 交替优化用户/物品向量↓
定期评估(NDCG, Precision, Recall)↓
构建 HNSW 索引(加速搜索)↓
保存模型

3. 关键参数

参数 推荐值 影响
n_factors 16-32 模型容量
n_epochs 50-200 训练时间
lr (BPR) 0.01-0.1 收敛速度
reg 0.001-0.1 过拟合控制

4. 性能优化

  • 并行训练:使用多核 CPU(jobs=8
  • 早停策略:防止过拟合(patience=5
  • 超参数优化:自动搜索最佳参数(optimize_trials=20
  • HNSW 索引:加速推荐生成(O(log N) 复杂度)

5. 实战建议

  1. 小数据集:BPR + 16 维 + 100 轮
  2. 大数据集:ALS + 32 维 + 50 轮
  3. 冷启动严重:增加 init_std_dev
  4. 过拟合:增加 reg 或减少 n_factors
  5. 收敛慢:调整 lr 或启用早停

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1182732.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

毫米波雷达十年演进

下面这份内容&#xff0c;不是“毫米波雷达从 77GHz 到 4D 成像”的产品路线图&#xff0c;也不是“毫米波是不是只能做 ACC 的老传感器”的工程偏见&#xff0c;而是站在 “毫米波雷达作为自动驾驶系统中唯一天然具备‘速度、距离、存在性’鲁棒感知能力的物理安全传感器”高度…

【小程序毕设全套源码+文档】基于微信小程序的农产品管理与销售APP设计与实现(丰富项目+远程调试+讲解+定制)

博主介绍&#xff1a;✌️码农一枚 &#xff0c;专注于大学生项目实战开发、讲解和毕业&#x1f6a2;文撰写修改等。全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围&#xff1a;&am…

2026年比较好的保温装饰一体板,装饰保温一体板,真石漆保温装饰一体板厂家新品推荐榜 - 品牌鉴赏师

引言在建筑行业蓬勃发展的当下,保温装饰一体板作为一种集保温与装饰功能于一体的新型建筑材料,正逐渐成为市场的主流选择。它不仅能有效提升建筑物的保温性能,降低能源消耗,还能为建筑外观增添美观度。目前市场上的…

new python project setup

python. UV + Ruff + ty + pytest + coveragepygithub. pre-commit + depedabot security checker + template for pull requests + template for issue + GitHub Actions CI (tests, type, lint, coverage upload) + …

Napprenez pas lamricain, lukrainien ou le russe

Dautres me lont dit, bien quils naient pas mentionn la langue amricaine.

【信息科学与工程学】第二篇 材料工程01 材料科学 (1)

材料科学核心知识体系&#xff1a;标准、概念、规则与方程一、材料科学全领域判断逻辑总图二、材料标准体系框架1. 国际标准体系概览标准体系主要制定机构适用范围典型标准系列ISO标准​国际标准化组织全球通用ISO 9001&#xff08;质量体系&#xff09;&#xff0c;ISO 6892&a…

实验用冻干机常见故障诊断与日常维护策略 - 品牌推荐大师

实验用真空冷冻干燥机(简称冻干机)是生物、医药、材料等领域保存热敏性样品的核心设备,其故障会直接影响样品活性与实验进度。本文结合设备工作原理,梳理常见故障诊断方法与全周期日常维护策略,帮助提升设备运行稳…

STM32F4的CAN升级方案 bootloader源代码,对应测试用app源代码,都是kei...

STM32F4的CAN升级方案 bootloader源代码&#xff0c;对应测试用app源代码&#xff0c;都是keil工程&#xff0c;代码有备注&#xff0c;也有使用说明。 带对应上位机可执行文件。 上位机vs2013开发(默认exe&#xff0c;源代码需要额外拿)STM32F4 系列 MCU 的在线升级&#xff0…

【小程序毕设源码分享】基于springboot+微信小程序的办公用品管理系统小程序的设计与实现(程序+文档+代码讲解+一条龙定制)

博主介绍&#xff1a;✌️码农一枚 &#xff0c;专注于大学生项目实战开发、讲解和毕业&#x1f6a2;文撰写修改等。全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围&#xff1a;&am…

2026年留学申请机构推荐:多国申请场景深度评价,针对文书质量与录取率核心痛点 - 品牌推荐

摘要 随着全球高等教育交流的持续深化与人才流动的加速,选择专业的留学申请服务机构已成为众多学子及家庭规划海外求学路径时的普遍考量。面对信息过载、院校政策动态变化以及跨文化申请流程的复杂性,决策者往往陷入…

spring boot的@Async注解有什么坑?

Spring Boot 中 Async 注解的常见坑&#xff08;2025-2026 真实生产环境高频问题汇总&#xff0c;按严重程度排序&#xff09; 排名坑的名称严重程度发生概率典型表现/后果解决/规避方案&#xff08;推荐做法&#xff09;1同一个类内部方法调用不生效★★★★★★★★★★内部…

2026年1月树枝/竹子粉碎机优选厂家:威威机械三十载匠心深耕农林加工领域 - 深度智识库

2026年1月正值农林废弃物集中处理、春季育苗备料的关键周期,树枝、竹子等纤维质物料的高效粉碎需求显著攀升。随着行业集中度逐步提升,深耕小型破碎领域三十余载的郑州市伟巍机械有限公司(旗下品牌“威威”),凭借…

如果希望做c++相关的工作,该如何系统学习c++?

如果希望做 C 相关的工作&#xff0c;该如何系统学习 C&#xff1f; &#xff08;2025-2026 年最现实的就业导向学习路径&#xff09; 以下路径按照真正能找到工作的优先级排序&#xff0c;而不是按照“语言特性出现的先后顺序”。 不同目标对应的现实学习时长与难度对比&…

成都硕士留学中介口碑排名出炉,申请成功率高的机构不容错过 - 留学机构评审官

成都硕士留学中介口碑排名出炉,申请成功率高的机构不容错过一、成都硕士留学中介如何选择?高成功率机构有哪些?在搜索引擎中,“成都硕士留学中介哪家好?”、“成都留学机构申请成功率高吗?”是本地学生与家长反复…

Qwen Code CLI - Skill引用

前提:最新版Qwen Code CLI 目前skill还只是实验性特性,文档中强调需要通过--experimental-skills启用,但后面又说明可通过setting配置开启此特性 Agent 技能(实验性) | Qwen Code Docs CLI形式:qwen --experimen…

长沙Top10研究生留学机构推荐:收费透明,服务优质 - 留学机构评审官

长沙Top10研究生留学机构推荐:收费透明,服务优质一、 如何筛选值得信赖的长沙研究生留学中介?在长沙寻求研究生留学服务的学生与家长,常常面临几个核心关切:如何确保中介费用的透明度,避免后续隐形消费?服务流程…

Kdenlive v25.12.1:免费开源多轨道视频剪辑工具

Kdenlive v25.12.1 是一款基于 Qt、KDE 及 MLT 框架构建的免费开源专业视频剪辑工具&#xff0c;集成 FFmpeg 开源工具&#xff0c;支持多轨道编辑、全格式兼容等核心功能&#xff0c;无论是基础剪辑需求还是专业创作场景&#xff0c;都能为用户提供流畅且强大的视频编辑体验。…

B站m4s视频快速转换完整教程:轻松突破播放限制

B站m4s视频快速转换完整教程&#xff1a;轻松突破播放限制 【免费下载链接】m4s-converter 将bilibili缓存的m4s转成mp4(读PC端缓存目录) 项目地址: https://gitcode.com/gh_mirrors/m4/m4s-converter 还在为B站缓存视频无法在其他设备播放而烦恼吗&#xff1f;那些精心…

福州Top10研究生留学机构,高录取率如何助力留学申请成功? - 留学机构评审官

福州Top10研究生留学机构,高录取率如何助力留学申请成功?我是一名从业八年的国际教育规划师,日常工作便是为不同背景的学生剖析留学申请的底层逻辑,并协助他们筛选合适的支持资源。在福州,许多意向深造的研究生申…

Taro跨端开发:从“多端适配焦虑“到“一次编写,处处运行“的蜕变之旅

Taro跨端开发&#xff1a;从"多端适配焦虑"到"一次编写&#xff0c;处处运行"的蜕变之旅 【免费下载链接】taro 开放式跨端跨框架解决方案&#xff0c;支持使用 React/Vue/Nerv 等框架来开发微信/京东/百度/支付宝/字节跳动/ QQ 小程序/H5/React Native 等…