The idea is simple: set requires_grad=False on all of LLaVA's parameters and train only your own modules (for example, the mask head). A few common patterns are given below; pick whichever one fits.
- Freeze all of LLaVA (vision tower + projector + language model)
```python
import torch

def freeze_llava(llava):
    for p in llava.parameters():
        p.requires_grad = False

# In your wrapper (using XVQAModel as an example)
model = XVQAModel(llava_model=llava, ...)  # "..." stands for your other constructor args
freeze_llava(model.llava)

# Hand only the modules you actually want to train to the optimizer (e.g., the mask head)
optimizer = torch.optim.AdamW(
    params=[p for p in model.mask_head_deocder.parameters() if p.requires_grad],
    lr=1e-4,
)
```
Make sure the optimizer only receives the trainable parameters; otherwise the frozen ones get registered as well (they won't be updated, but they waste time and memory).
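If you would rather not enumerate modules by hand, a minimal sketch that filters at the top-level model also works; it relies on everything except your own heads already being frozen:

```python
# Collect whatever is still trainable anywhere in the wrapper model.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
```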
- Selective freezing (keep only the projector or certain layers trainable)
Train only the mm_projector (a common fine-tuning setup):
```python
def freeze_all_but_projector(llava):
    for p in llava.parameters():
        p.requires_grad = False
    for p in llava.mm_projector.parameters():
        p.requires_grad = True

freeze_all_but_projector(model.llava)

optimizer = torch.optim.AdamW(
    list(model.mask_head_deocder.parameters()) + list(model.llava.mm_projector.parameters()),
    lr=1e-4,
)
```
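Note that where the projector lives depends on the LLaVA wrapper; the snippet above assumes it is exposed directly as llava.mm_projector. As a hedged sketch, a small helper can try the common locations (the official LLaVA repo reaches it via get_model().mm_projector, while the Hugging Face port names it multi_modal_projector):

```python
def get_projector(llava):
    # Try common attribute paths; adjust for your wrapper if none of them match.
    if hasattr(llava, "mm_projector"):
        return llava.mm_projector
    if hasattr(llava, "get_model") and hasattr(llava.get_model(), "mm_projector"):
        return llava.get_model().mm_projector      # official LLaVA repo
    if hasattr(llava, "multi_modal_projector"):
        return llava.multi_modal_projector         # HF LlavaForConditionalGeneration
    raise AttributeError("could not locate the projector on this LLaVA wrapper")
```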
Train only the last N transformer layers of the language model (here N=2 as an example):
```python
def freeze_all_but_last_n_transformer_blocks(llava, n=2):
    for p in llava.parameters():
        p.requires_grad = False
    # LLaVA's LLM blocks usually live at llava.model.layers
    for blk in llava.model.layers[-n:]:
        for p in blk.parameters():
            p.requires_grad = True

freeze_all_but_last_n_transformer_blocks(model.llava, n=2)
```
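To confirm the right blocks were unfrozen, a quick check under the same llava.model.layers assumption:

```python
# Print the indices of transformer blocks that still have trainable parameters.
for i, blk in enumerate(model.llava.model.layers):
    if any(p.requires_grad for p in blk.parameters()):
        print(f"block {i} is trainable")
```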
Train only the vision tower (uncommon):
```python
def freeze_all_but_vision(llava):
    for p in llava.parameters():
        p.requires_grad = False
    # Adjust the path if your wrapper nests it (e.g., llava.get_vision_tower() in the official repo)
    for p in llava.vision_tower.parameters():
        p.requires_grad = True
```
- Quick sanity checks after freezing
```python
def count_trainable_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print("trainable in XVQAModel:", count_trainable_params(model))
print("trainable in LLaVA:", count_trainable_params(model.llava))
print("trainable in mask head:", count_trainable_params(model.mask_head_deocder))
```