[AutoParallel] fix sp test accuracy for H20 #76150
Open
+1
−1
PR Category
Auto Parallel
PR Types
Bug fixes
Description
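test_semi_auto_parallel_for_llama_subnet has an accuracy problem on H20 cards. The test runs DemoNet with dp+mp+pp and checks whether the results with sp enabled and with sp disabled align to rtol=1e-7. On H20 they do not align.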
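To give a sense of how strict rtol=1e-7 is, here is a minimal NumPy sketch. It assumes float32 values and a comparison in the style of `np.testing.assert_allclose`. Neither assumption is confirmed by this PR, and the helper `aligned` is hypothetical. Under those assumptions, two adjacent float32 values near 1.0 already fail the check, so the test is effectively asking for bitwise-level agreement.

```python
import numpy as np

# Two neighbouring float32 values: they differ by one unit in the last place.
a = np.float32(1.0)
b = np.nextafter(a, np.float32(2.0))
rel_diff = float(abs(b - a) / abs(a))  # roughly float32 machine epsilon


def aligned(x, y, rtol=1e-7):
    """Hypothetical helper: True if x and y agree within rtol (atol=0)."""
    try:
        np.testing.assert_allclose(x, y, rtol=rtol, atol=0)
        return True
    except AssertionError:
        return False


print("relative diff of adjacent float32 values:", rel_diff)
print("aligned(a, a):", aligned(a, a))
print("aligned(a, b):", aligned(a, b))
```

Any change in reduction order that flips even the last bit can therefore break such a check.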
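Analysis:
1. With sp enabled and disabled, the forward and backward sharding states and the communication ops all match expectations.
2. The sp switch takes effect right before layer norm, and this is what causes the accuracy diff.
With sp disabled, the layer norm input is in partial state. SPMD first allreduces it to replicate state, the output is also replicate, and the computed result is global.
With sp enabled, the partial state is turned into shard state by sp's reduce_scatter. The layer norm input is in shard state, the SPMD output is also in shard state, and the computed result is local.
So the two results differ. Normally the diff is very small and does not affect convergence in the end.
However, this DemoNet's loss is on the order of 1e5 and does not converge, so it is very sensitive to tiny accuracy diffs. Small differences grow as they propagate, and the unit test can no longer align.
3. Hack to confirm: with sp enabled, reshard to replicate before the norm. With this change, the sp-enabled and sp-disabled runs align.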
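The "global vs. local" effect in items 2 and 3 can be illustrated with a toy NumPy sketch. This is not Paddle code, and the PR does not say which part of layer norm produces the diff. As an assumption, the sketch uses the gradient of the layer norm scale, which is a reduction over rows: computing it on the full tensor, versus per shard and then summing, accumulates in a different order in float32. All names here (`layer_norm_fwd`, `gamma_grad`, the shapes, the number of ranks) are hypothetical.

```python
import numpy as np


def layer_norm_fwd(x, eps=1e-5):
    """Row-wise layer norm without scale/bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gamma_grad(x, dy):
    """Gradient w.r.t. the layer norm scale: sum over rows of dy * x_hat."""
    return (dy * layer_norm_fwd(x)).sum(axis=0)


rng = np.random.default_rng(0)
seq, hidden, ranks = 64, 16, 4  # hypothetical sizes
x = (rng.standard_normal((seq, hidden)) * 1e3).astype(np.float32)
dy = rng.standard_normal((seq, hidden)).astype(np.float32)

# sp off: the norm sees the replicated (global) tensor.
global_grad = gamma_grad(x, dy)

# sp on: each rank sees only its shard along the sequence axis;
# the per-rank partial gradients are summed afterwards.
x_shards = np.array_split(x, ranks, axis=0)
dy_shards = np.array_split(dy, ranks, axis=0)
local_grads = [gamma_grad(xs, ds) for xs, ds in zip(x_shards, dy_shards)]
sharded_grad = np.sum(local_grads, axis=0)

# Hack from item 3: gather the shards back (reshard to replicate) before the norm.
regathered_grad = gamma_grad(np.concatenate(x_shards), np.concatenate(dy_shards))

max_abs_diff = float(np.abs(global_grad - sharded_grad).max())
print("global vs sharded, max abs diff:", max_abs_diff)
print("global vs regathered identical:", np.array_equal(global_grad, regathered_grad))
```

In this sketch the regathered path repeats the global computation on identical data, so it matches exactly. The sharded path is mathematically equivalent but may differ in the last bits, and a strict tolerance can flag that.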
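Fix:
The diff is caused by the way the network is currently built. The Paddle mechanism itself is fine, so changing the network is the better fix.
The network has a residual connection, out = out + tgt, before the norm. Uncomment that line. With sp enabled, this addition combines a shard(0) tensor with a shard(1) tensor, which introduces an allreduce, so the norm computes on the global value and the sp-enabled and sp-disabled runs align.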
PCard-93779