diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index b80e5f1d089..456b72a1510 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -26,9 +26,9 @@ jobs: test: name: verl Ascend test (self-host) runs-on: [self-hosted, npu-0] - timeout-minutes: 5 # Increase this timeout value as needed + timeout-minutes: 30 # Increase this timeout value as needed container: - image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10 + image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 volumes: - /usr/local/dcmi:/usr/local/dcmi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi @@ -42,6 +42,13 @@ jobs: --device /dev/hisi_hdc --privileged --network "host" + --shm-size 2g + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable steps: - name: Check npu and CANN info run: | @@ -49,6 +56,42 @@ jobs: npu-smi info - name: Checkout volcengine/verl repo uses: actions/checkout@v4 - - name: Run test + - name: Install torch run: | - lscpu + pip install torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu + pip install torch-npu==2.5.1 + pip install /usr/local/Ascend/ascend-toolkit/latest/lib64/te-0.4.0-py3-none-any.whl + - name: Install vllm + run: | + apt-get update && apt-get install -y git + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git vllm-npu + cd vllm-npu + pip install -r requirements-build.txt + VLLM_TARGET_DEVICE=empty pip install -e . 
--extra-index-url https://download.pytorch.org/whl/cpu/ + - name: Install vllm-ascend + run: | + pip list + pip show torch + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + cd vllm-ascend + export COMPILE_CUSTOM_KERNELS=1 + python setup.py install + - name: Install the current repository + run: | + pip3 install hf_transfer peft + pip3 install -r requirements-npu.txt + pip install -e . + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e training tests with LoRA on ASCEND NPU + run: | + ray stop --force + bash tests/e2e/sft/run_sft.sh + rm -rf $HOME/ckpts + - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU + run: | + ray stop --force + bash tests/npu/run_qwen2_5_05b_grpo.sh + rm -rf $HOME/ckpts \ No newline at end of file diff --git a/docs/ascend/ascend_vllm073.rst b/docs/ascend/ascend_vllm073.rst new file mode 100644 index 00000000000..600b64950aa --- /dev/null +++ b/docs/ascend/ascend_vllm073.rst @@ -0,0 +1,82 @@ +verl x Ascend +============= + +我们在 verl 上增加对华为昇腾设备的支持。 + +硬件支持 +======= + +* Atlas 800T A2 + +* Atlas 200T A2 Box16 + +安装 +======= + +环境准备 +------ + ++-----------+-------------+ +| 软件 | 版本 | ++-----------+-------------+ +| Python | == 3.10 | +| torch | == 2.5.1 | +| torch_npu | == 2.5.1rc1 | +| CANN | == 8.1.RC1 | ++-----------+-------------+ + +1. 使用 vLLM,需遵循 vllm-ascend 的安装教程 。 +2. 为了能够在 ASCEND NPU 上正常使能 flash_attention_2, transformers 版本需要大于等于 4.52.0。 +3. 目前支持 SFT 与 LLM 模型的 GRPO 训练,VLM模型的 GRPO 训练因为 vllm-ascend 的问题将会在后续支持,涉及到的issue为: + +https://github.com/vllm-project/vllm-ascend/issues/809 + +https://github.com/vllm-project/vllm-ascend/issues/825 + +源码安装 +------ + +.. code-block:: bash + git clone https://github.com/volcengine/verl.git + cd verl + pip install -r requirements-npu.txt + pip install -e . 
+ +vLLM +------ + +为了保证能够在 verl 上正常使用 vLLM,需要安装 vLLM Ascend 插件(`vllm-ascend`)。关于在华为昇腾上支持的 vLLM 版本以及和 vLLM Ascend 的配套关系请参考`安装教程 <https://vllm-ascend.readthedocs.io/en/latest/installation.html>`_。 + +其他第三方库说明 +---------------- + ++--------------+--------+ +| 软件 | 说明 | ++--------------+--------+ +| flash_attn | 不支持 | ++--------------+--------+ +| liger-kernel | 不支持 | ++--------------+--------+ + +精度对比 +------ + +根据经验,对于SFT等微调算法,我们期望在相同配置下,在华为昇腾设备上的 Loss 与英伟达 GPU 的 Loss 平均绝对误差小于等于 2%,具体计算方式如下: + +.. image:: https://raw.githubusercontent.com/eric-haibin-lin/verl-community/main/docs/loss_comparison.png + :alt: loss_comparison + +其中,N 表示训练的步数。更多信息请参考`精度计算说明 <https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html>`_。 + +根据经验,对于GRPO等强化学习算法,我们期望在相同配置下,在华为昇腾设备上的 reward 与英伟达 GPU 的 reward 平均绝对误差小于等于 4%,具体计算参考 Loss 计算。 + +进展 +------ + ++--------+--------+ +| 算法 | 进展 | ++--------+--------+ +| SFT | 已支持 | ++--------+--------+ +| GRPO | 已支持 | ++--------+--------+ diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index 27df2cc7438..6d0f4a721ea 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -21,6 +21,7 @@ import ray from .dapo_ray_trainer import RayDAPOTrainer +from verl.utils.device import is_cuda_available def get_custom_reward_fn(config): diff --git a/recipe/sppo/main_sppo.py b/recipe/sppo/main_sppo.py index eae1e43e343..25b1c469e7c 100644 --- a/recipe/sppo/main_sppo.py +++ b/recipe/sppo/main_sppo.py @@ -25,6 +25,7 @@ from verl.trainer.ppo.reward import load_reward_manager from .sppo_ray_trainer import RaySPPOTrainer +from verl.utils.device import is_cuda_available @hydra.main(config_path="config", config_name="sppo_trainer", version_base=None) @@ -140,6 +141,7 @@ def run(self, config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name="cuda" if is_cuda_available else "npu", ) trainer.init_workers() trainer.fit() diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py index 0e870a0facd..761def940bc 
100644 --- a/recipe/sppo/sppo_ray_trainer.py +++ b/recipe/sppo/sppo_ray_trainer.py @@ -86,6 +86,7 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + device_name="cuda", ): self.tokenizer = tokenizer self.processor = processor @@ -105,6 +106,7 @@ def __init__( self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls self.validation_generations_logger = ValidationGenerationsLogger() + self.device_name = device_name # define in-reward KL control # kl loss control currently not suppoorted diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 00000000000..601e8f9fa6e --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,20 @@ +# requirements.txt records the full set of dependencies for development +accelerate +codetiming +datasets +dill +hydra-core +numpy +pandas +peft +pyarrow>=15.0.0 +pybind11 +pylatexenc +ray +tensordict<=0.6.2 +transformers>=4.52.0 +wandb +mathruler +torchdata +einops +qwen_vl_utils diff --git a/tests/npu/run_qwen2_5_05b_grpo.sh b/tests/npu/run_qwen2_5_05b_grpo.sh new file mode 100644 index 00000000000..ed44063d59d --- /dev/null +++ b/tests/npu/run_qwen2_5_05b_grpo.sh @@ -0,0 +1,42 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=512 \ + data.max_response_length=128 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + 
actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_32b_grpo.sh b/tests/npu/run_qwen2_5_32b_grpo.sh new file mode 100644 index 00000000000..d83e36b843f --- /dev/null +++ b/tests/npu/run_qwen2_5_32b_grpo.sh @@ -0,0 +1,43 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6\ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + 
actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_32b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_7b_grpo.sh b/tests/npu/run_qwen2_5_7b_grpo.sh new file mode 100644 index 00000000000..8ee7445b469 --- /dev/null +++ b/tests/npu/run_qwen2_5_7b_grpo.sh @@ -0,0 +1,44 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.model.use_remove_padding=False \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 $@ \ No newline at end of file diff --git a/verl/__init__.py b/verl/__init__.py index 9f4fa70d77f..d1b8547bca3 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -14,9 +14,13 @@ import logging import os +import pkg_resources +from pkg_resources import DistributionNotFound +from packaging.version import parse as parse_version from .protocol import DataProto from .utils.logging_utils import set_basic_config +from .utils.device import is_npu_available version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) @@ -38,3 +42,17 @@ from modelscope.utils.hf_util import patch_hub patch_hub() + +if is_npu_available: + package_name = 'transformers' + required_version_spec = '4.52.0' + try: + installed_version = 
pkg_resources.get_distribution(package_name).version + installed = parse_version(installed_version) + required = parse_version(required_version_spec) + + if not installed >= required: + raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.") + except DistributionNotFound: + raise ImportError( + f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}") diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 91112ca6029..f306d26ac8d 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -33,7 +33,7 @@ ) try: - from flash_attn import flash_attn_func, flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) except ImportError: diff --git a/verl/protocol.py b/verl/protocol.py index 5b729134cec..64682a4d469 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -36,6 +36,7 @@ from verl.utils.py_functional import union_two_dict from verl.utils.torch_functional import allgather_dict_tensors +from verl.utils.device import get_torch_device __all__ = ["DataProto", "union_tensor_dict"] @@ -272,7 +273,7 @@ def __setstate__(self, data): batch_deserialized_bytes, non_tensor_batch, meta_info = data batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) - batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None) + batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None) self.batch = batch self.non_tensor_batch = non_tensor_batch self.meta_info = meta_info @@ -802,7 +803,7 @@ def all_gather_data_proto(data: DataProto, process_group): group_size = 
torch.distributed.get_world_size(group=process_group) assert isinstance(data, DataProto) prev_device = data.batch.device - data.batch = data.batch.cuda(device=torch.cuda.current_device()) + data.batch = data.batch.to(get_torch_device().current_device()) data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 5305fd6fb38..7be25fbe087 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -23,6 +23,7 @@ import ray from .decorator import Dispatch, Execute, register +from verl.utils.device import get_torch_device @dataclass @@ -147,10 +148,9 @@ def __init__(self, cuda_visible_devices=None) -> None: import torch from packaging import version - ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES") os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK") ### @@ -168,7 +168,7 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): self.local_rank = int(os.environ["LOCAL_RANK"]) cuda_visible_devices = str(local_rank) ### @@ -188,8 +188,8 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() 
and version.parse(ray.__version__) < version.parse("2.45.0"): - torch.cuda.set_device(int(cuda_visible_devices)) + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + get_torch_device().set_device(int(cuda_visible_devices)) ### self.fused_worker_dict = {} diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index c0822e3cf27..ed086fc797e 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -95,13 +95,17 @@ def __init__( self.pgs = None self.detached = detached - def get_placement_groups(self, strategy="STRICT_PACK", name=None): + def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): if self.pgs is not None: return self.pgs pg_name_prefix = name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" # print(f"pg_name_prefix = {pg_name_prefix}") - pg_scheme = [[{"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] + if device_name == "npu": + device_name = "NPU" + elif device_name == "cuda": + device_name = "GPU" + pg_scheme = [[{"CPU": self.max_colocate_count, device_name: 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] lifetime = "detached" if self.detached else None @@ -174,7 +178,7 @@ def update_options(self, options: Dict): """ self._options.update(options) - def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None) -> Any: + def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None, device_name="cuda") -> Any: """Create and return a Ray actor with the configured options. 
Args: @@ -183,6 +187,7 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = use_gpu: Whether to use GPU resources num_gpus: Number of GPUs to allocate sharing_with: Actor to share resources with + device_name: Device for training Returns: A Ray actor handle with the configured options @@ -196,8 +201,10 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = options = {"scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx)} options.update(self._options) - if use_gpu: + if use_gpu and device_name == "cuda": options["num_gpus"] = num_gpus + if use_gpu and device_name == "npu": + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -227,6 +234,7 @@ def __init__( worker_names=None, worker_handles: List[ray.actor.ActorHandle] = None, ray_wait_register_center_timeout: int = 300, + device_name="cuda", **kwargs, ) -> None: """Initialize a RayWorkerGroup. @@ -249,6 +257,7 @@ def __init__( self.fused_worker_used = ray_cls_with_init.fused_worker_used # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to this WorkerGroup. 
self.sub_cls_name = "" + self.device_name = device_name if worker_names is not None and (not self.fused_worker_used): assert self._is_init_with_detached_workers @@ -300,7 +309,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d strategy = "PACK" if bin_pack: strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy) + pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size # cia.add_kwarg("_world_size", world_size) @@ -339,7 +348,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d ray_cls_with_init.update_options({"lifetime": "detached"}) # create a worker - worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus) + worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus, device_name=self.device_name) self._workers.append(worker) self._worker_names.append(name) diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 633f2f4f9b4..018b5ca3662 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -30,7 +30,6 @@ import hydra import torch import torch.distributed -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from peft import LoraConfig, TaskType, get_peft_model from tensordict import TensorDict from torch import nn, optim @@ -55,8 +54,15 @@ get_ulysses_sequence_parallel_world_size, ulysses_pad_and_slice_inputs, ) +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from 
transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis + logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) @@ -108,6 +114,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM # TODO: add checkpoint manager if self.device_mesh.get_rank() == 0: print(self.config) + self.device_name = get_device_name() def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) @@ -244,7 +251,7 @@ def _build_model_optimizer(self): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), cpu_offload=cpu_offload, use_orig_params=False, ) @@ -280,15 +287,15 @@ def _compute_loss_and_backward(self, batch, do_backward=True): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch["input_ids"].cuda() - attention_mask = batch["attention_mask"].cuda() - position_ids = batch["position_ids"].cuda() - loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).cuda() + input_ids = batch["input_ids"].to(self.device_name) + attention_mask = batch["attention_mask"].to(self.device_name) + position_ids = batch["position_ids"].to(self.device_name) + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) loss_fct = nn.CrossEntropyLoss(reduction="none") # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -398,15 +405,23 @@ def training_step(self, batch: 
TensorDict): log_gpu_memory_usage("After offload weights", logger=logger) - step_loss = torch.tensor(step_loss).cuda() - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} + step_loss = torch.tensor(step_loss).to(self.device_name) + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.ulysses_device_mesh.size(0) + return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval() with torch.no_grad(): loss = self._compute_loss_and_backward(batch, do_backward=False) - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.ulysses_device_mesh.size(0) return loss def save_checkpoint(self, step): @@ -461,7 +476,7 @@ def fit(self): desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", ): global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -471,7 +486,7 @@ def fit(self): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -487,7 +502,7 @@ def fit(self): # validation val_losses = [] for data in self.val_dataloader: - data = 
TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -502,11 +517,12 @@ def fit(self): @hydra.main(config_path="config", config_name="sft_trainer", version_base=None) def main(config): + device_name = get_device_name() local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) # build tokenizer and datasets first from verl.utils import hf_tokenizer diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index 0d12b5b4b92..0f80b7caf91 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -38,6 +38,7 @@ from verl.utils.hdfs_io import makedirs from verl.utils.model import compute_position_id_with_mask from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.utils.device import is_cuda_available @hydra.main(config_path="config", config_name="generation", version_base=None) @@ -81,7 +82,7 @@ def main_task(config): ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg 
= RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu") wg.init_model() total_samples = len(dataset) diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 4e004fbff31..77e20343993 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -22,6 +22,7 @@ from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available def get_custom_reward_fn(config): @@ -178,6 +179,7 @@ def run(self, config): val_dataset=val_dataset, collate_fn=collate_fn, train_sampler=train_sampler, + device_name="cuda" if is_cuda_available else "npu", ) trainer.init_workers() trainer.fit() diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 56bc5b75f91..91117780f07 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -123,7 +123,7 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} + node_available_gpus = {node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) for node, node_info in node_available_resources.items()} # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) @@ -284,9 +284,8 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + device_name="cuda", ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' - self.tokenizer = tokenizer self.processor = processor self.config = config @@ -304,6 +303,7 @@ def __init__( self.use_reference_policy = Role.RefPolicy in role_worker_mapping self.use_rm = 
Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name self.validation_generations_logger = ValidationGenerationsLogger() # define in-reward KL control @@ -716,7 +716,7 @@ def init_workers(self): for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index 1914d475513..c9ac414f370 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -22,6 +22,7 @@ import torch.distributed from filelock import FileLock from transformers import PreTrainedTokenizer, ProcessorMixin +from verl.utils.device import is_cuda_available, is_npu_available class BaseCheckpointManager: @@ -107,18 +108,27 @@ def local_mkdir(path): def get_rng_state(): rng_state = { "cpu": torch.get_rng_state(), - "cuda": torch.cuda.get_rng_state(), "numpy": np.random.get_state(), "random": random.getstate(), } + + if is_cuda_available: + rng_state["cuda"] = torch.cuda.get_rng_state() + elif is_npu_available: + rng_state["npu"] = torch.npu.get_rng_state() + return rng_state @staticmethod def load_rng_state(rng_state): torch.set_rng_state(rng_state["cpu"]) - torch.cuda.set_rng_state(rng_state["cuda"]) np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) + + if is_cuda_available: + torch.cuda.set_rng_state(rng_state["cuda"]) + elif is_npu_available: + torch.npu.set_rng_state(rng_state["npu"]) def find_latest_ckpt_path(path, 
directory_format="global_step_{}"): diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index bf9206d5498..b556f298412 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -24,6 +24,7 @@ from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin from verl.utils.fs import copy_to_local, is_non_local +from verl.utils.device import is_cuda_available from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx from .checkpoint_manager import BaseCheckpointManager @@ -96,8 +97,8 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: @@ -127,8 +128,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i torch.distributed.barrier() # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with warnings.catch_warnings(): warnings.simplefilter("ignore") with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): diff --git a/verl/utils/debug/performance.py 
b/verl/utils/debug/performance.py index f461b3f47a8..fd1b7c40d1c 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -19,18 +19,19 @@ import torch.distributed as dist from verl.utils.logger.aggregate_logger import DecoratorLoggerBase +from verl.utils.device import get_torch_device def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]: """Get current memory usage.""" assert unit in ["GB", "MB", "KB"] divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 - mem_allocated = torch.cuda.memory_allocated() - mem_reserved = torch.cuda.memory_reserved() - # use torch.cuda.mem_get_info to profile device memory + mem_allocated = get_torch_device().memory_allocated() + mem_reserved = get_torch_device().memory_reserved() + # use get_torch_device().mem_get_info to profile device memory # since vllm's sleep mode works below pytorch # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 - mem_free, mem_total = torch.cuda.mem_get_info() + mem_free, mem_total = get_torch_device().mem_get_info() mem_used = mem_total - mem_free mem_allocated = f"{mem_allocated / divisor:.{precision}f}" mem_reserved = f"{mem_reserved / divisor:.{precision}f}" diff --git a/verl/utils/device.py b/verl/utils/device.py new file mode 100644 index 00000000000..ee9e279d212 --- /dev/null +++ b/verl/utils/device.py @@ -0,0 +1,57 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 7aa30c16815..101d972b6b0 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -14,6 +14,7 @@ """Utilities for distributed training.""" import os +from verl.utils.device import is_cuda_available, get_torch_device def initialize_global_process_group(timeout_second=36000): @@ -21,11 +22,11 @@ def initialize_global_process_group(timeout_second=36000): import torch.distributed - torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group("nccl" if is_cuda_available else "hccl", timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = 
int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + get_torch_device().set_device(local_rank) return local_rank, rank, world_size diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 5eefbb7400e..c25bec89b2a 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig +from verl.utils.device import is_cuda_available, get_torch_device VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "deepseek_v3"} @@ -29,7 +30,7 @@ def unit_convert(number, level): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = get_torch_device().get_device_name() flops = float("inf") # INF flops for unkown gpu type if "MI300X" in device_name: diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index c645cfe9787..53af8798736 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -29,6 +29,7 @@ from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name +from verl.utils.device import get_torch_device, get_device_name if version.parse(torch.__version__) >= version.parse("2.6"): from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard @@ -40,8 +41,8 @@ def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + get_torch_device().empty_cache() return x @@ -144,7 +145,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if 
empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -152,7 +153,7 @@ def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): for param in model.parameters(): param.data = param.data.to(torch.device("cpu"), non_blocking=True) if empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -165,12 +166,12 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model loading to GPU" - device_id = torch.cuda.current_device() + device_id = get_torch_device().current_device() for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True) + handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ -279,7 +280,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = get_torch_device().current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -317,7 +318,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = get_torch_device().current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 7ae5176bbc8..af9730b7f33 100644 --- a/verl/workers/actor/dp_actor.py +++ 
b/verl/workers/actor/dp_actor.py @@ -23,7 +23,6 @@ from typing import Tuple import torch -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -37,6 +36,13 @@ from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis + __all__ = ["DataParallelPPOActor"] @@ -64,6 +70,7 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim if self.config.get("use_torch_compile", True) # use torch compile by default else verl_F.entropy_from_logits ) + self.device_name = get_device_name() if self.use_fused_kernels: from verl.utils.experimental.torch_functional import FusedLinearForPPO @@ -86,7 +93,7 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -359,9 +366,9 @@ def update_policy(self, data: DataProto): for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = 
{**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload + data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) attention_mask = data["attention_mask"] diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index dc83028a620..a28e18573f6 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -34,8 +34,13 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic +from verl.utils.device import get_device_name, get_torch_device, is_npu_available, is_cuda_available -__all__ = ["DataParallelPPOCritic"] + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -50,6 +55,7 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt print(f"Critic use_remove_padding={self.use_remove_padding}") self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch["responses"].size(-1) @@ -58,7 +64,7 @@ def _forward_micro_batch(self, micro_batch): for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, 
dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -207,9 +213,9 @@ def update_critic(self, data: DataProto): for data in micro_batches: # Support all devices if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload + data = data.to(get_torch_device().current_device()) # critic device is cpu when using offload responses = data["responses"] attention_mask = data["attention_mask"] values = data["values"] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 48114746b67..ae6e28faa93 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -54,16 +54,20 @@ from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available + logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +device_name = get_device_name() + def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]) return device_mesh @@ -93,7 +97,7 @@ def __init__(self, config: DictConfig, role: str): if not 
torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) - torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", rank=rank, world_size=world_size) + torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl", rank=rank, world_size=world_size) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -105,7 +109,7 @@ def __init__(self, config: DictConfig, role: str): self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -279,7 +283,7 @@ def _build_model_optimizer( param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -354,7 +358,7 @@ def _build_rollout(self, trust_remote_code=False): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) rollout_name = self.config.rollout.name if rollout_name == "hf": from 
verl.workers.rollout import HFRollout @@ -566,13 +570,13 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) @@ -583,8 +587,8 @@ def update_actor(self, data: DataProto): global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) lr = self.actor_lr_scheduler.get_last_lr()[0] @@ -609,7 +613,7 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) + prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout @@ -639,7 +643,7 @@ def generate_sequences(self, prompts: DataProto): output = output.to("cpu") # clear kv cache - 
torch.cuda.empty_cache() + get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -649,7 +653,7 @@ def compute_log_prob(self, data: DataProto): load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -683,7 +687,7 @@ def compute_ref_log_prob(self, data: DataProto): assert self._is_ref # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -740,7 +744,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -754,7 +758,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -871,7 +875,7 @@ def _build_critic_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, 
auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -959,7 +963,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -982,11 +986,11 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: @@ -1056,7 +1060,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -1070,7 +1074,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, 
self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -1135,7 +1139,7 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), @@ -1164,11 +1168,14 @@ def init_model(self): self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -1289,7 +1296,7 @@ def compute_rm_score(self, data: DataProto): from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: @@ -1304,7 +1311,7 @@ def compute_rm_score(self, data: DataProto): rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) + rm_data.batch = rm_data.batch.to(get_torch_device().current_device()) # perform forward 
computation with self.ulysses_sharding_manager: diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py index 7b04e51ef66..ff34407d77e 100644 --- a/verl/workers/rollout/hf_rollout.py +++ b/verl/workers/rollout/hf_rollout.py @@ -29,6 +29,7 @@ from verl import DataProto from verl.utils.torch_functional import get_response_mask +from verl.utils.device import get_torch_device from .base import BaseRollout @@ -165,7 +166,7 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: ) # empty cache before compute old_log_prob - torch.cuda.empty_cache() + get_torch_device().empty_cache() self.module.train() return DataProto(batch=batch) diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 114b9646fb2..133db24ef84 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -35,6 +35,7 @@ from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.torch_functional import check_cuda_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from verl.utils.device import get_torch_device from .base import BaseShardingManager @@ -84,26 +85,26 @@ def __init__( self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = 
get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __enter__(self): - # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and + # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory # to speed up memory allocations. # # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - torch.cuda.empty_cache() + get_torch_device().empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) if self.offload_param: @@ -132,7 +133,7 @@ def __enter__(self): del params if self.offload_param: offload_fsdp_model_to_cpu(self.module) - torch.cuda.empty_cache() + get_torch_device().empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) @@ -141,8 +142,8 @@ def __enter__(self): # important: need to manually set the random states of each tp to be identical. 
if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): @@ -158,12 +159,12 @@ def __exit__(self, exc_type, exc_value, traceback): self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: @@ -194,6 +195,6 @@ def postprocess_data(self, data: DataProto) -> DataProto: def update_params(self, updated_params): model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) - device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy + device = get_torch_device().current_device() # used when fsdp2 set cpu_offload_policy loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items())) logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))