现象: loss不下降
过程如下:
1.减少层数,准备最小复现环境
 2.dropout设置为0,重复运行二次,对比loss是否一致
 3.第二次迭代开始loss不一致
 4.对比backward之后的梯度,发现某一个梯度不一致
 5.dump得到所有算子的规模,单算子测试功能正常
 6.怀疑是内存越界导致
 7.排除通信库的问题,逐算子bypass
 8.dump reduce_scatter的输入,发现每次都不样
 9.在异常的时候pause进程,在python调用reduce_scatter的位置打印调用栈
 10.定位到有问题的模块,是一个融合算子
 11.用普通算子替换,结果一致
 12.复测这个规模的融合算子功能正常
 13.怀疑算子内部有内存踩踏行为
 14.将输入类型从fp16改为fp32,结果正常
 15.review该算子内部实现,确实有几行代码将输入当fp32处理