
[NPU] feat: Support FSDP worker and vLLM Ascend#332

Merged
vermouth1992 merged 23 commits into verl-project:main from sunyi0505:vllm-0.7-npu
May 23, 2025
Conversation

@sunyi0505 sunyi0505 commented Feb 21, 2025

For developers, you can follow the docs: docs/ascend/ascend.rst

This PR adds support for the Ascend NPU backend.
Co-authored-by: Chendong98 chendong136@huawei.com
Co-authored-by: zheliuyu 15750543867@163.com
Co-authored-by: celestialli celestialli@outlook.com
In this PR, we add the capability to detect the NPU device type, and we also add a new script for training on NPU.

Here is the list of changes:

  1. pyproject.toml: change the version of vLLM
  2. requirements-npu.txt: requirements for NPU
  3. verl/bert_padding.py: adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
  4. verl/single_controller/ray/base.py
  5. verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py
  6. verl/trainer/fsdp_sft_trainer.py
  7. verl/utils/flops_counter.py
  8. verl/utils/fsdp_utils.py
  9. verl/workers/actor/dp_actor.py
  10. verl/workers/critic/dp_critic.py
  11. verl/workers/fsdp_workers.py
  12. verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
  13. verl/workers/sharding_manager/fsdp_vllm.py
  14. verl/utils/device.py: get the device type for different devices
  15. docs/ascend/ascend.md
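Item 14 above introduces device-type detection. A minimal, hypothetical sketch of what such a helper can look like (the real verl/utils/device.py may differ in names and details):

```python
import importlib.util


def detect_device_type() -> str:
    """Pick a device string based on which accelerator package is importable.

    Hypothetical sketch; the real verl/utils/device.py may differ.
    """
    if importlib.util.find_spec("torch_npu") is not None:
        return "npu"
    if importlib.util.find_spec("torch") is not None:
        # the real helper would additionally check torch.cuda.is_available()
        return "cuda"
    return "cpu"
```

Keeping the check behind one helper lets the workers stay device-agnostic: call sites ask for a device name instead of branching on `torch.cuda` directly.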

Here is our roadmap:

Roadmap

  • sft
  • ppo
  • grpo

News

[2025.03.31] Added results for SFT and GRPO. Qwen2-7B-Instruct was tested on 2*8 devices, and many batch-size-related parameters had to be reduced, so these results are for reference only. We will publish the reward results with the default parameters as soon as sleep mode is supported.

[2025.03.03] Modified the Ray adaptation approach.

[2025.02.25] The PPO algorithm is supported for training on NPU with the FSDP backend.

[2025.02.23] The SFT algorithm is supported for training on NPU with the FSDP backend.

[2025.02.21] The GRPO algorithm is supported for training on NPU with the FSDP backend.

Requirements
We tested this PR on both Ascend NPU and GPU to ensure the same code can run on different devices. The hardware used was 8 Atlas 800T A2 NPUs and 8 A100 GPUs. Other software versions are shown in the following table.

Software Version
transformers 4.47.1
accelerate 1.3.0
torch_npu 2.5.1.rc1
CANN 8.1.RC1 (Not Released)

About mean error
Due to differences in hardware architecture, we cannot guarantee that the loss on Ascend NPU exactly matches that on GPU. In our experience, a mean loss difference of less than 2% is acceptable; if the difference exceeds 2%, we will try to fix it. The calculation formula is as follows.
(image: loss-comparison formula)

N represents the number of training steps. For more information, please refer to the Calculation accuracy description.
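The formula itself is only available as the linked image; as a hedged reading of the surrounding text, the mean error over N steps can be computed like this (function name and exact definition are an assumption, not the official formula):

```python
def mean_relative_error(npu_losses, gpu_losses):
    """Mean relative error between per-step NPU and GPU losses over N steps.

    Hypothetical rendering of the linked formula; the authoritative definition
    is in the Ascend calculation-accuracy documentation.
    """
    assert len(npu_losses) == len(gpu_losses) and npu_losses
    n = len(npu_losses)
    return sum(abs(a - b) / abs(b) for a, b in zip(npu_losses, gpu_losses)) / n
```

Under this reading, the 2% threshold means `mean_relative_error(npu_losses, gpu_losses) < 0.02` over a full training run.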

@sunyi0505 sunyi0505 changed the title support ASCEND NPU [WIP] support ASCEND NPU Feb 21, 2025
@huangk10

does this pr work on multi nodes?

@sunyi0505
Collaborator Author

does this pr work on multi nodes?

I am currently conducting tests on a single node only, and will subsequently supplement with multi-node testing results.

@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 2 times, most recently from 0afd136 to d496b70 Compare February 21, 2025 07:59
@sunyi0505 sunyi0505 changed the title [WIP] support ASCEND NPU Support FSDP worker and vLLM Ascend Feb 21, 2025
@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 10 times, most recently from 8b1b207 to 0b7e274 Compare February 22, 2025 06:48
@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 2 times, most recently from 62af61c to fd62e2e Compare February 24, 2025 01:27
@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 3 times, most recently from 45f208b to d36c1c7 Compare February 25, 2025 08:07
@CLAassistant

CLAassistant commented Feb 26, 2025

CLA assistant check
All committers have signed the CLA.

@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 3 times, most recently from 6314fcf to d4309a8 Compare March 3, 2025 07:21
@zheliuyu
Contributor

transformers v4.51.4 starts supporting Ascend NPU so that flash_attention_2 can be enabled directly. The transformers section in the README may need to be adjusted accordingly.

@sunyi0505 sunyi0505 requested a review from antonlisq May 16, 2025 09:22
@sunyi0505
Collaborator Author

transformers v4.51.4 starts supporting Ascend NPU so that flash_attention_2 can be enabled directly. The transformers section in the README may need to be adjusted accordingly.


Thank you for your suggestion. I will make the necessary changes in the future.
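As a side note on the version gating discussed above, a hypothetical helper for choosing an attention implementation from the installed transformers version might look like this (the 4.52.0 threshold follows the docs added in this PR; the `sdpa` fallback is an assumption):

```python
def pick_attn_implementation(transformers_version: str) -> str:
    """Choose an attention backend for NPU based on the transformers version.

    Hypothetical sketch: flash_attention_2 requires transformers >= 4.52.0 per
    the docs in this PR; falling back to "sdpa" otherwise is an assumption.
    """
    major, minor, *_ = (int(part) for part in transformers_version.split("."))
    if (major, minor) >= (4, 52):
        return "flash_attention_2"
    return "sdpa"
```

The returned string would be passed as the `attn_implementation` argument when loading a model with transformers.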

vLLM
------

To use vLLM properly with verl, the vLLM Ascend plugin (`vllm-ascend`) must be installed. For the vLLM versions supported on Huawei Ascend and their compatibility with vLLM Ascend, see the `installation guide <https://vllm-ascend.readthedocs.io/en/v0.7.1rc1/installation.html>`_.

This installation guide URL points to the v0.7.1rc1 version; suggest updating it to the v0.7.3 documentation.


------
Ray

There appears to be no content under this Ray section; please confirm, and if there is indeed none, suggest removing the section heading.

Accuracy comparison
------

Empirically, for fine-tuning algorithms such as SFT, we expect the mean error between the loss on Huawei Ascend devices and that on NVIDIA GPUs, under the same configuration, to be less than 2%. The calculation is as follows:

Suggest changing this to "mean absolute error less than or equal to 2%".


where N represents the number of training steps. For more information, see the [accuracy calculation description](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html).

Empirically, for reinforcement learning algorithms such as GRPO, we expect the mean absolute error between the reward on Huawei Ascend devices and that on NVIDIA GPUs, under the same configuration, to be less than 4%; the calculation follows the loss calculation above.

Suggest changing this to "mean absolute error less than or equal to 4%".

pybind11
pylatexenc
ray
tensordict<0.6

What is the reason for pinning a maximum version here?

import torch
import torch.distributed
from filelock import FileLock
from verl.utils.device import is_cuda_available, is_npu_available

Imports of verl-internal modules should come after third-party libraries.

attn_output = self.proj(attn_output)
return attn_output



Suggest documenting the reason for this patch here, and noting that it will be removed once the underlying issue is resolved.

from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import verl.utils.torch_functional as verl_F

This appears to have been deleted by mistake and should be restored; it is related to graph mode.

from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from verl.utils.device import get_device_name, is_cuda_available, get_torch_device, is_npu_available

Adjust the import order to: from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available

world_size = torch.distributed.get_world_size()
device = torch.cuda.current_device()  # used when fsdp2 sets cpu_offload_policy
- loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items()))
+ loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))

Please confirm this change; if it is not necessary, align with the main branch.
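The change under discussion swaps a `hasattr` check for an `isinstance` check when deciding whether a sharded parameter must be gathered before loading. A minimal sketch of that pattern, using a stand-in class since the real `DTensor` lives in torch.distributed and needs a process group:

```python
class FakeDTensor:
    """Stand-in for torch.distributed.tensor.DTensor, for illustration only."""

    def __init__(self, local_value):
        self._local_value = local_value

    def full_tensor(self):
        # the real DTensor gathers shards from all ranks here
        return self._local_value


def materialize(param, dtensor_cls=FakeDTensor):
    """Gather a sharded tensor into a full tensor; pass plain values through."""
    return param.full_tensor() if isinstance(param, dtensor_cls) else param
```

The type check is independent of world size, which is why it can behave differently from the `world_size != 1 and hasattr(...)` condition it replaces.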

@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 2 times, most recently from ef80e67 to 58b943b Compare May 20, 2025 03:32

Since VL models are not supported for now, suggest leaving this patch out of the first PR and adding it later once vllm-ascend explicitly supports them.

verl/__init__.py Outdated
patch_hub()

if is_npu_available:
from .utils import npu_patch

Same as above; adjust accordingly.

@@ -0,0 +1,30 @@
# Tested with 1 & 8 NPUs

Does "1 & 8 NPUs" mean 8 cards on a single machine? If so, suggest changing it to "Tested on 1 node with 8 NPUs"; the same applies to the shell script below.

@@ -0,0 +1,44 @@
# Tested with 1 & 8 NPUs

Suggest unifying the file naming style; recommend renaming to run_qwen2_5_05b_grpo.sh.

verl/__init__.py Outdated
patch_hub()

if is_npu_available:


Please remove the extra blank line.

@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 2 times, most recently from d151b48 to 36b1589 Compare May 21, 2025 07:32
@vermouth1992 vermouth1992 changed the title [WIP] Support FSDP worker and vLLM Ascend [NPU] feat: Support FSDP worker and vLLM Ascend May 21, 2025
@sunyi0505 sunyi0505 force-pushed the vllm-0.7-npu branch 4 times, most recently from c4cc95d to c5ac75d Compare May 22, 2025 01:21

1. To use vLLM, follow the vllm-ascend installation guide <https://vllm-ascend.readthedocs.io/en/v0.7.3/installation.html>.
2. To enable flash_attention_2 properly on Ascend NPU, the transformers version must be 4.52.0 or later.
3. GRPO training is currently supported for LLM models; GRPO training for VLM models will be supported later due to issues in vllm-ascend. The related issue is:

The text below mentions SFT support, while this passage says only GRPO is supported; the statements are contradictory, please confirm.
