很多时候嵌入式或者新硬件需要纯净的权重模型和激活值(运行时中间值),本文提供一种最简洁的方法。
假设已经有模型model和pt文件了,在当前目录下新建weights文件夹,运行最后三行代码,就可以得到模型的权重(文本形式和二进制形式)
model.load_state_dict(state_dict)for name, param in model.named_parameters():print(param.data.numpy(),file=open(f"weights/{name}.txt", "w"))param.data.numpy().tofile(f"weights/{name}.bin")
对于二进制形式的文件,可以通过od -t f4 <binary file name>
查看其对应的浮点数值。f4
表示fp32.
打印forward的中间值:
def hook_fn(module, input, output):key = str(module)intermediate_outputs = {}intermediate_outputs[key+"-input"] = inputintermediate_outputs[key+"-output"] = outputprint(intermediate_outputs)def register_hooks(model):for name, layer in model.named_children():# print(name, layer) # dump all layerslayer.register_forward_hook(hook_fn)# Recursively apply the same to all submodulesregister_hooks(layer)register_hooks(model)
其中regster_hooks
和以下等价(不需要recursive了)
def register_hooks(model):for name, layer in model.named_modules():# print(name, layer) # dump all layerslayer.register_forward_hook(hook_fn)
其中nn.sequential
作为一个整体,目前没办法拆开来看其内部的中间值。