Multi-node distributed pretraining problem with Yuan-2.0-2.1B: FileNotFoundError: [Errno 2] No such file or directory: 'home/Yuan-2.0-main/megatron/fused_kernels/build/lock' #86
Please confirm whether the code path is on a directory that is shared across all of the nodes.
I ran into the same problem: the single-node multi-GPU script runs fine, but multi-node multi-GPU runs hit this bug caused by the read/write lock used during the C++ kernel compilation. I have already solved it; just add the --no-masked-softmax-fusion flag to torchrun in the training script:

torchrun $DISTRIBUTED_ARGS pretrain_yuan.py \
    $GPT_ARGS \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    $LOG_ARGS \
    --distributed-backend nccl \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --no-masked-softmax-fusion  # <--- add this argument

P.S. This problem is actually a Megatron-LM bug; someone had already opened the same issue over there, and it looks like the NVIDIA developers have since fixed it (they mention the fix in this reply). I compared the buggy code snippet and found that they have indeed removed the code related to compiling fused_kernel (see the comparison here), so I would suggest that Yuan-2.0 update its current Megatron-related code.
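For context on why a 'lock' file is involved at all, below is a minimal sketch of the rank-0-compiles-first pattern that Megatron-LM (and Yuan-2.0's copy of it) uses around the fused-kernel JIT build. The helper name compile_fused_kernels is made up, and the sketch assumes fused_kernels.load() goes through torch.utils.cpp_extension, which serializes builds with a file named lock inside megatron/fused_kernels/build/. When processes on several nodes race on one shared build directory, one of them can remove that lock while another still expects it, which matches the FileNotFoundError above; the --no-masked-softmax-fusion workaround presumably sidesteps the JIT build of the masked-softmax kernels and therefore the race.

```python
# Minimal sketch, not Yuan-2.0's actual code. Assumptions: torch.distributed is
# already initialized, and fused_kernels.load(args) JIT-compiles the CUDA kernels
# through torch.utils.cpp_extension, which guards the build with a file named
# 'lock' inside megatron/fused_kernels/build/.
import torch
from megatron import fused_kernels  # import path as used by Megatron-LM


def compile_fused_kernels(args):  # hypothetical helper name
    """Compile on global rank 0 first, then let every other rank load the cache."""
    if torch.distributed.get_rank() == 0:
        fused_kernels.load(args)  # builds the .so files, touches build/lock
    # Every rank waits here; non-zero ranks continue only after rank 0 is done.
    torch.distributed.barrier()
    if torch.distributed.get_rank() != 0:
        # Expects the cached build (and its lock handling) to be consistent.
        # On multi-node runs sharing one build directory, this is where the
        # missing-'lock' race can surface.
        fused_kernels.load(args)
```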
It is indeed on a directory shared by all of the nodes.
The problem has been solved, thank you very much.
The code runs fine on a single node, but on multiple nodes the master gets stuck at this step:
setting number of micro-batches to constant 96
and the other worker reports an error:
Failures:
<NO_OTHER_FAILURES>
Root Cause (first observed failure):
[0]:
time : 2024-01-05_13:42:59
host : pytorch-2e4abf84-worker-0
rank : 7 (local_rank: 3)
exitcode : 1 (pid: 104)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
According to the error, the failure happens during the C++ compilation: there is no lock file under the home/Yuan-2.0-main/megatron/fused_kernels/build directory. I searched related issues; some suggested deleting the build directory and recompiling, but I tried that and it did not work either.
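For anyone who wants to double-check the first question in this thread (whether the repository path is really shared by every node), here is a small sketch of one way to probe it under the same torchrun launch. The script name check_shared_fs.py, the marker file name probe_shared_fs, and REPO_DIR are placeholders, not part of Yuan-2.0:

```python
# Hypothetical helper, not part of Yuan-2.0: verify that REPO_DIR is visible to
# every rank, i.e. that it lives on a filesystem shared across all nodes.
import os
import socket
import torch.distributed as dist

REPO_DIR = "/path/to/Yuan-2.0-main"                 # placeholder: your checkout
PROBE = os.path.join(REPO_DIR, "probe_shared_fs")   # arbitrary marker file name

dist.init_process_group(backend="gloo")             # gloo is enough for this check
if dist.get_rank() == 0:
    with open(PROBE, "w") as f:                     # rank 0 creates the marker
        f.write(socket.gethostname())
dist.barrier()                                      # make sure the file exists first
visible = os.path.exists(PROBE)                     # every rank checks its own view
print(f"rank {dist.get_rank()} on {socket.gethostname()}: sees marker = {visible}")
dist.barrier()
if dist.get_rank() == 0:
    os.remove(PROBE)                                # clean up the marker file
dist.destroy_process_group()
```

Launch it with the same multi-node settings, e.g. torchrun $DISTRIBUTED_ARGS check_shared_fs.py; if any rank prints False, the directory is not actually shared and the fused-kernel build cache cannot be reused across nodes.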