view
PyTorch 的view() 是张量「重塑(Reshape)」函数,用于改变张量的维度形状但不改变数据本身
在多头注意力中,view()的核心作用是将总隐藏维度拆分为「注意力头数 × 单头维度」,实现多头并行计算
核心规则
tensor.view(*shape)作用:将张量重塑为指定的shape,要求「新形状的元素总数 = 原张量的元素总数」(否则报错)
核心特性:
不改变张量的底层数据,仅改变维度的 “视图”(轻量操作,无数据拷贝)
重塑后的张量与原张量共享内存(修改一个,另一个也会变)
支持用-1自动推导某一维度的大小(仅能有一个-1)
importtorch# 一维张量重塑为二维x=torch.arange(12)# shape=(12,),元素总数=12x_view1=x.view(3,4)# shape=(3,4),3×4=12x_view2=x.view(4,-1)# -1自动推导为3,shape=(4,3)print(x_view1.shape)# torch.Size([3,4])print(x_view2.shape)# torch.Size([4,3])# 三维张量重塑(核心:元素总数不变)y=torch.randn(2,6,768)# 2×6×768=9216y_view=y.view(2,6,12,64)# 2×6×12×64=9216print(y_view.shape)# torch.Size([2,6,12,64])- 关键注意事项
报错场景:新形状元素总数≠原总数 → x.view(3,5)(12≠15)会报错
-1的用法:仅能指定一个-1,用于自动计算维度(如view(2,-1,64))
内存连续性:若张量内存不连续(如经过transpose/permute),需先调用contiguous()再view,否则报错
多头注意力中view的核心作用
将总隐藏维度d_model拆分为num_heads × d_k(单头维度),view()是实现这一拆分的关键
完整流程如下:
步骤 1:先明确核心参数(以 BERT-base 为例)
batch_size=2(批次)、seq_len=6(序列长度);
d_model=768(总隐藏维度)、num_heads=12(注意力头数)、d_k=64(单头维度,768=12×64);
输入querys:shape=(2,6,768)(经W_query线性变换后的输出)。
步骤 2:用view()拆分注意力头
# 1. 原始querys:[batch, seq_len, d_model] = [2,6,768]querys=torch.randn(2,6,768)# 2. 拆分为多头:[2,6,12,64](batch, seq_len, num_heads, d_k)querys_heads=querys.view(2,6,12,64)print(querys_heads.shape)# torch.Size([2,6,12,64])# 3. 转置调整维度(为后续批量矩阵乘法):[2,12,6,64]# 注:transpose后内存不连续,需contiguous()才能再viewquerys_heads=querys_heads.transpose(1,2).contiguous()print(querys_heads.shape)# torch.Size([2,12,6,64])步骤 3:注意力计算后,用view()合并多头
# 假设注意力计算后的输出:[2,12,6,64](batch, num_heads, seq_len, d_k)attn_output=torch.randn(2,12,6,64)# 1. 先转置回原维度:[2,6,12,64]attn_output=attn_output.transpose(1,2).contiguous()# 2. 合并多头:[2,6,768](还原为总隐藏维度)attn_output_merged=attn_output.view(2,6,768)print(attn_output_merged.shape)# torch.Size([2,6,768])多头注意力完整实战代码
importtorchimporttorch.nnasnnclassMultiHeadAttention(nn.Module):def__init__(self,d_model=768,num_heads=12):super().__init__()self.d_model=d_model self.num_heads=num_heads self.d_k=d_model//num_heads# 64,用//保证整除# 定义Q/K/V线性层self.W_query=nn.Linear(d_model,d_model)self.W_key=nn.Linear(d_model,d_model)self.W_value=nn.Linear(d_model,d_model)defforward(self,x):# x: [batch_size, seq_len, d_model] = [2,6,768]batch_size,seq_len=x.shape[0],x.shape[1]# 1. 线性变换:Q/K/V均为[2,6,768]Q=self.W_query(x)K=self.W_key(x)V=self.W_value(x)# 2. 拆分为多头:[2,6,12,64]Q=Q.view(batch_size,seq_len,self.num_heads,self.d_k)K=K.view(batch_size,seq_len,self.num_heads,self.d_k)V=V.view(batch_size,seq_len,self.num_heads,self.d_k)# 3. 转置:[2,12,6,64](batch, num_heads, seq_len, d_k)# 必须contiguous(),否则后续view会报错Q=Q.transpose(1,2).contiguous()K=K.transpose(1,2).contiguous()V=V.transpose(1,2).contiguous()# 4. 计算注意力分数:Q @ K^T → [2,12,6,6]K_T=K.transpose(2,3)# [2,12,64,6]attn_scores=Q @ K_T# [2,12,6,6]# 5. softmax归一化(省略,核心看view)attn_weights=torch.softmax(attn_scores,dim=-1)# 6. 加权求和:[2,12,6,64]attn_output=attn_weights @ V# 7. 转置+合并多头:[2,6,768]attn_output=attn_output.transpose(1,2).contiguous()attn_output=attn_output.view(batch_size,seq_len,self.d_model)returnattn_output测试代码
mha=MultiHeadAttention(d_model=768,num_heads=12)x=torch.randn(2,6,768)output=mha(x)print(output.shape)# 输出:torch.Size([2,6,768])view vs reshape
新手常混淆view和reshape,二者均用于重塑张量,但核心差异如下:
特性 view() reshape()
内存共享 与原张量共享内存(无拷贝) 优先共享内存,不连续则拷贝新内存
内存连续性 要求张量内存连续(否则报错) 自动处理内存不连续,无需contiguous()
适用场景 内存连续的张量(如线性层输出) 内存不连续的张量(如 transpose 后)
大模型开发建议
若确定张量内存连续(如线性层输出、原始输入),用view()(更高效)
若张量经过transpose/permute(如多头注意力中的转置),用reshape()(无需手动contiguous())
permute:重排,置换
示例:
# 替代:transpose后直接reshape,更简洁Q=Q.transpose(1,2).reshape(batch_size,self.num_heads,seq_len,self.d_k)总结
view()核心作用:改变张量维度形状,不改变数据,要求元素总数不变,支持-1自动推导维度
多头注意力中view()的核心用法
拆分:将[batch, seq_len, d_model]拆为[batch, seq_len, num_heads, d_k]
合并:注意力计算后,将[batch, seq_len, num_heads, d_k]合并回[batch, seq_len, d_model]
关键注意:transpose/permute后需contiguous()才能用view(),或直接用reshape()更便捷
contiguous
contiguous()用于将「内存不连续」的张量转换为「内存连续」的张量,保证张量的元素在内存中按维度顺序紧密排列
是view()等操作的前置必要条件
张量在计算机内存中是一维线性存储的
“连续” 指的是:张量的元素在内存中的排列顺序,和按「维度顺序(如从 0 维到最后一维)」遍历张量得到的顺序完全一致
直观示例(二维张量)
假设有张量x = torch.tensor([[1,2,3], [4,5,6]])(shape=(2,3)):
连续内存布局:内存中存储顺序是 1 → 2 → 3 → 4 → 5 → 6(按 “行优先” 遍历,先遍历 0 维,再遍历 1 维)
若对x做转置x.T,得到[[1,4], [2,5], [3,6]]
转置后的张量逻辑上是 3 行 2 列,但内存中仍存储为1→2→3→4→5→6(PyTorch 的transpose/permute仅修改 “维度视图”,不拷贝数据)
此时按转置后的维度遍历(行优先),期望顺序是1→4→2→5→3→6,但内存实际顺序不符 → 转置后的张量是内存不连续的
为什么张量会变得 “不连续”?
PyTorch 中以下操作会导致张量内存不连续(核心是 “只改视图,不改内存”):
- 维度变换类:transpose()、permute()(最常见,如多头注意力中的维度交换)
- 索引 / 切片类:非连续切片(如x[:, ::2])、高级索引
- 其他操作:narrow()、expand()(部分场景)
这些操作的设计初衷是 “轻量”—— 避免不必要的数据拷贝,提升效率,但代价是破坏了内存连续性
contiguous()的核心作用
contiguous()会创建一个新的内存连续的张量:
- 新张量与原张量数据相同,但内存排列会按照 “当前维度顺序” 重新整理
- 新张量与原张量不再共享内存(是数据拷贝操作)
- 只有内存连续的张量,才能调用view()(view()要求张量元素在内存中是连续的,否则无法正确重塑维度)
实战示例(结合多头注意力的经典场景)
importtorch# 1. 创建连续张量x=torch.randn(2,6,768)# shape=(2,6,768),内存连续print(x.is_contiguous())# 输出:True# 2. 转置后内存不连续x_trans=x.transpose(1,2)# 交换1、2维,shape=(2,768,6)print(x_trans.is_contiguous())# 输出:False# 3. 直接调用view()会报错try:x_trans.view(2,768,12,5)# 768×6=12×60?不,768×6=4608=12×384,这里故意错,核心看报错exceptExceptionase:print("报错:",e)# 报错:view size is not compatible with input tensor's size and stride...# 4. 先contiguous()再view(),正常运行x_contig=x_trans.contiguous()print(x_contig.is_contiguous())# 输出:Truex_view=x_contig.view(2,768,12,64)# 2×768×12×64=2×768×768=1179648,和2×768×6=9216?哦,修正:x_trans.shape=(2,768,6),总元素=2×768×6=9216;view为2,768,12,0.5?重新来:x_view=x_contig.view(2,768,12,0.5)# 故意错,实际应保证总元素一致:x_view=x_contig.view(2,768,12,0.5)→ 正确示例: x_contig=x_trans.contiguous()x_view=x_contig.view(2,12,64,6)# 2×12×64×6=2×768×6=9216,总元素一致print(x_view.shape)# 输出:torch.Size([2, 12, 64, 6])大模型开发中的核心应用场景(必掌握)
contiguous()几乎只在「transpose()/permute() + view()」的组合中使用,尤其是多头注意力层:
# 多头注意力中拆分注意力头的标准流程Q=torch.randn(2,6,768)# [batch, seq_len, d_model]Q=Q.view(2,6,12,64)# 拆分为多头:[2,6,12,64]Q=Q.transpose(1,2)# 交换维度:[2,12,6,64] → 内存不连续Q=Q.contiguous()# 转为连续内存后续可安全调用view()(若需要)
Q=Q.view(2,12,6,64)# 正常运行关键注意事项
contiguous()是数据拷贝操作:会消耗内存和时间,非必要时不要调用(比如仅做矩阵乘法,无需连续内存)
替代方案:reshape()会自动处理内存连续性(优先共享内存,不连续则自动拷贝),因此:
若只需重塑维度,用reshape()替代contiguous()+view()更简洁
示例:Q.transpose(1,2).reshape(2,12,6,64)(无需手动contiguous())
判断是否连续:用tensor.is_contiguous()快速检查,返回True则为连续
总结
contiguous()的核心:将内存不连续的张量转为连续,保证view()等操作能正常执行
触发场景:张量经过transpose()/permute()后,若要调用view(),必须先contiguous()
大模型实战建议:优先用reshape()替代contiguous()+view(),减少代码量且更安全