非常好,这段代码是一个非常典型的 L2 正则化(L2 Regularization) 的实现例子。
我们来一步步讲清楚它的作用、为什么写成这样,以及 torch.norm(param, p=2)
的意义。
一、先看这段代码在干什么
# 手动计算 L2 正则化项并加到总损失中
l2_reg = 0
for param in model.parameters():l2_reg += torch.norm(param, p=2) # 计算每个参数的 L2 范数(即平方和的平方根)# 最终损失 = 原始损失 + L2 正则化项
total_loss = loss + l2_lambda * l2_reg
这段代码的逻辑是:
- 遍历模型的所有参数;
- 计算每个参数张量的 L2 范数(即权重的平方和开方);
- 把所有参数的 L2 范数加起来,得到正则项;
- 把它乘以一个超参数
l2_lambda
(正则化系数); - 最后与原始
loss
相加,形成新的总损失total_loss
。
二、为什么要写 torch.norm(param, p=2)
torch.norm
是计算张量的范数(norm)的函数。
常用形式:
torch.norm(tensor, p=2)
表示计算 L2 范数,即:
举个例子:
import torch
w = torch.tensor([3.0, 4.0])
torch.norm(w, p=2) # = sqrt(3² + 4²) = 5
在机器学习里,我们并不是单纯想知道权重有多大,而是希望:
在损失函数中惩罚权重过大的情况,防止模型过拟合。
于是我们把这个 L2 范数作为惩罚项加入总损失里。
三、为什么要加 L2 正则化项?
这是为了防止模型过拟合(overfitting)。
原始损失(如交叉熵)只关注“模型预测得对不对”;
但我们希望模型的参数也“不要太极端”(太大、太尖锐)。
所以我们加入一项惩罚:
其中 (\lambda) 是权重衰减系数,控制惩罚力度。
它的意义是:
- 如果权重过大 → 惩罚项大 → 让优化器“缩小”权重;
- 如果权重适中 → 惩罚项小 → 不影响原始损失。
这相当于告诉模型:
“别太依赖某个特征,保持整体权重平衡。”
四、L2 正则化的数学形式
从数学上讲,
L2 正则化项是:
而 PyTorch 的 torch.norm(param, p=2)
返回的是:
通常我们会用平方形式(不取平方根),即:
如果你想完全对应公式,可以改成:
l2_reg += torch.norm(param, p=2) ** 2
这样结果是 L2 范数平方,更严格地符合理论定义。
五、它和 weight_decay
的关系
其实你不用手动写这段代码也可以实现同样效果。
在 PyTorch 的优化器中,比如 Adam、SGD,都有一个参数叫:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
weight_decay
本质上就是在每次梯度更新时加上一个 L2 正则项(权重衰减项),
等价于你手动加的这一段。
手动写通常是为了:
- 控制哪些层加正则、哪些不加;
- 或者在复杂损失中单独调整正则项权重。
六、总结一句话
torch.norm(param, p=2)
是在计算参数张量的 L2 范数(权重平方和的平方根),
它作为惩罚项加入总损失中,起到限制模型复杂度、防止过拟合的作用。
✅ 结构化总结:
部分 | 含义 |
---|---|
torch.norm(param, p=2) |
计算每个参数的 L2 范数 |
l2_reg += ... |
累加所有参数的范数 |
total_loss = loss + λ * l2_reg |
在总损失中加上正则惩罚 |
目的 | 抑制权重过大、防止过拟合 |
对应公式 | ( \mathcal{L}{total} = \mathcal{L} + \lambda \sum_i w_i^2 ) |