diff --git a/.github/workflows/checkpoint_converter.yml b/.github/workflows/checkpoint_converter.yml index 639b63b6524..cea6dbf16e3 100644 --- a/.github/workflows/checkpoint_converter.yml +++ b/.github/workflows/checkpoint_converter.yml @@ -14,15 +14,22 @@ on: - v0.* paths: - "**/*.py" - # Entrypoints - - ".github/workflows/checkpoint_converter.yml" - - "!examples" + # Other entrypoints + - "!examples/**" + - "!tests/**" - "!verl/trainer/main_*.py" - "!verl/trainer/fsdp_sft_trainer.py" # Recipes - - "!recipe" + - "!recipe/**" # FSDP - "!verl/workers/**/*dp_*.py" + # Entrypoints + - ".github/workflows/checkpoint_converter.yml" + - ".github/workflows/e2e_ppo_trainer_megatron.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/e2e/run_ppo_trainer_megatron.sh" + - "verl/trainer/main_ppo.py" + - "verl/trainer/config/ppo_megatron_trainer.yaml" # Cancel jobs on the same ref if a new one is triggered diff --git a/.github/workflows/e2e_prime.yml b/.github/workflows/disabled/e2e_prime.yml similarity index 97% rename from .github/workflows/e2e_prime.yml rename to .github/workflows/disabled/e2e_prime.yml index 50c1c0a37cd..61c7e86cfb9 100644 --- a/.github/workflows/e2e_prime.yml +++ b/.github/workflows/disabled/e2e_prime.yml @@ -5,12 +5,10 @@ on: # but only for the main branch push: branches: - - main - - v0.* + - disabled_ci pull_request: branches: - - main - - v0.* + - disabled_ci paths: - "**/*.py" # Other entrypoints diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index b80e5f1d089..456b72a1510 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -26,9 +26,9 @@ jobs: test: name: verl Ascend test (self-host) runs-on: [self-hosted, npu-0] - timeout-minutes: 5 # Increase this timeout value as needed + timeout-minutes: 30 # Increase this timeout value as needed container: - image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10 + image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 volumes: - /usr/local/dcmi:/usr/local/dcmi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi @@ -42,6 +42,13 @@ jobs: --device /dev/hisi_hdc --privileged --network "host" + --shm-size 2g + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable steps: - name: Check npu and CANN info run: | @@ -49,6 +56,42 @@ jobs: npu-smi info - name: Checkout volcengine/verl repo uses: actions/checkout@v4 - - name: Run test + - name: Install torch run: | - lscpu + pip install torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu + pip install torch-npu==2.5.1 + pip install /usr/local/Ascend/ascend-toolkit/latest/lib64/te-0.4.0-py3-none-any.whl + - name: Install vllm + run: | + apt-get update && apt-get install -y git + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git vllm-npu + cd vllm-npu + pip install -r requirements-build.txt + VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ + - name: Install vllm-ascend + run: | + pip list + pip show torch + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + cd vllm-ascend + export COMPILE_CUSTOM_KERNELS=1 + python setup.py install + - name: Install the current repository + run: | + pip3 install hf_transfer peft + pip3 install -r requirements-npu.txt + pip install -e . + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e training tests with LoRA on ASCEND NPU + run: | + ray stop --force + bash tests/e2e/sft/run_sft.sh + rm -rf $HOME/ckpts + - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU + run: | + ray stop --force + bash tests/npu/run_qwen2_5_05b_grpo.sh + rm -rf $HOME/ckpts \ No newline at end of file diff --git a/.github/workflows/e2e_dapo.yml b/.github/workflows/e2e_dapo.yml index 9698d51cbdc..784e2a071c6 100644 --- a/.github/workflows/e2e_dapo.yml +++ b/.github/workflows/e2e_dapo.yml @@ -23,7 +23,7 @@ on: # Megatron - "!verl/workers/**/megatron_*.py" # Home - - "recipe/dapo/src" + - "recipe/dapo" # Entrypoints - ".github/workflows/e2e_dapo.yml" - "examples/data_preprocess/gsm8k.py" @@ -34,7 +34,6 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - # Declare permissions just read content. permissions: contents: read diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index b9714ea4378..f6ab375363c 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -61,7 +61,7 @@ jobs: e2e_ppo_trainer_vllm: runs-on: [L20x8] - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -148,6 +148,14 @@ jobs: run: | ray stop --force LIGER=True bash tests/e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNELS=True bash tests/e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNEL=True FUSED_KERNEL_BACKEND=triton bash tests/e2e/ppo_trainer/run_model_reward.sh e2e_ppo_trainer_vllm_vlm: runs-on: [L20x8] @@ -182,6 +190,27 @@ jobs: MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \ ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + run: | + ray stop --force + FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ + bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + run: | + ray stop --force + FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \ + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ + bash tests/e2e/ppo_trainer/run_function_reward.sh e2e_ppo_trainer_sglang: runs-on: [L20x8] @@ -269,11 +298,15 @@ jobs: run: | ray stop --force bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh + - name: Running GSM8K with tool E2E training tests with FSDP2 + run: | + ray stop --force + FSDP_STRATEGY=fsdp2 bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh e2e_ppo_trainer_sglang_vlm: runs-on: [L20x8] needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -305,3 +338,24 @@ jobs: ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + run: | + ray stop --force + FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ + bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + run: | + ray stop --force + FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \ + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ + bash tests/e2e/ppo_trainer/run_function_reward.sh \ No newline at end of file diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index 42ad40207d3..f9dc924483e 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -40,51 +40,9 @@ permissions: contents: read jobs: - e2e_ppo_trainer_megatron-qwen: - runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with validation and saving - run: | - ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) after resuming - run: | - ray stop --force - RESUME_MODE=auto bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) - run: | - exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path Qwen/Qwen2.5-0.5B - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path Qwen/Qwen2.5-0.5B - - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) - run: | - ray stop --force - ADV_ESTIMATOR=grpo bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints e2e_ppo_trainer_megatron-deepseek: runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -107,26 +65,26 @@ jobs: - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) run: | ray stop --force - SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh + ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) run: | ray stop --force - RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek) - run: | - ray stop --force - ADV_ESTIMATOR=grpo MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh + RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 bash tests/e2e/run_ppo_trainer_megatron.sh - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) run: | exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct + python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek) + run: | + ray stop --force + ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh - name: clean up run: | rm -rf checkpoints e2e_ppo_trainer_megatron-qwen3: runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -134,7 +92,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -149,7 +107,7 @@ jobs: - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving run: | ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh + ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming run: | ray stop --force @@ -157,51 +115,18 @@ jobs: - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) run: | exp_name="qwen3-0.6b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) run: | ray stop --force - ADV_ESTIMATOR=grpo MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-different-train-infer-tp-qwen: - runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp > infer tp - run: | - ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=2 INFER_TP=1 bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp < infer tp - run: | - ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=1 INFER_TP=2 bash tests/e2e/run_ppo_trainer_megatron.sh + ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - name: clean up run: | rm -rf checkpoints e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding: runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -234,7 +159,7 @@ jobs: rm -rf checkpoints e2e_ppo_trainer_megatron-qwen-override-transformer-config: runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -266,14 +191,14 @@ jobs: - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) run: | exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path Qwen/Qwen2.5-0.5B - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path Qwen/Qwen2.5-0.5B + python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: clean up run: | rm -rf checkpoints e2e_ppo_trainer_megatron-deepseek-override-transformer-config: runs-on: [L20x8] - timeout-minutes: 30 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -300,8 +225,8 @@ jobs: - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) run: | exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct + python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: clean up run: | rm -rf checkpoints diff --git a/.github/workflows/e2e_spin.yml b/.github/workflows/e2e_spin.yml index 5ed75ff6bd6..0ec51115f88 100644 --- a/.github/workflows/e2e_spin.yml +++ b/.github/workflows/e2e_spin.yml @@ -13,6 +13,15 @@ on: - v0.* paths: - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" # Home - "recipe/spin" # Entrypoints @@ -20,10 +29,6 @@ on: - "examples/data_preprocess/gsm8k.py" - "tests/e2e/run_spin.sh" - "!examples" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Megatron - - "!verl/workers/**/megatron_*.py" # Declare permissions just read content. permissions: @@ -58,4 +63,4 @@ jobs: - name: Running the E2E test with the spin algorithm run: | ray stop --force - bash tests/e2e/run_spin.sh \ No newline at end of file + bash tests/e2e/run_spin.sh diff --git a/.github/workflows/e2e_sppo.yml b/.github/workflows/e2e_sppo.yml index 061450ff1d4..d2ee8fe8913 100644 --- a/.github/workflows/e2e_sppo.yml +++ b/.github/workflows/e2e_sppo.yml @@ -13,17 +13,21 @@ on: - v0.* paths: - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" # Home - "recipe/sppo" # Entrypoints - ".github/workflows/e2e_sppo.yml" - "examples/data_preprocess/gsm8k.py" - "tests/e2e/run_sppo.sh" - - "!examples" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Megatron - - "!verl/workers/**/megatron_*.py" # Declare permissions just read content. permissions: diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml index 1e049e71366..053419941b5 100644 --- a/.github/workflows/kernels.yml +++ b/.github/workflows/kernels.yml @@ -17,9 +17,16 @@ on: - v0.2.x paths: - "**/*.py" - - "verl/trainer/config/*.yaml" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # Entrypoints - .github/workflows/kernels.yml - - "tests/e2e/*.sh" + - "tests/kernels/*" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -52,4 +59,7 @@ jobs: pip3 install --no-deps -e .[test] - name: Testing LinearCrossEntropy Correction, Computation Time and Memory Consumption run: | - python3 tests/kernels/test_linear_cross_entropy.py \ No newline at end of file + python3 tests/kernels/p.py + - name: Testing LinearCrossEntropyTP Correction, Computation Time and Memory Consumption + run: | + torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/kernels/test_linear_cross_entropy_tp.py \ No newline at end of file diff --git a/README.md b/README.md index 3a00a36395d..a61f5ce3102 100644 --- a/README.md +++ b/README.md @@ -214,6 +214,7 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The - [Code-R1](https://github.com/ganler/code-r1): Reproducing R1 for **Code** with Reliable Rewards ![GitHub Repo stars](https://img.shields.io/github/stars/ganler/code-r1) - [Skywork-OR1](https://github.com/SkyworkAI/Skywork-OR1): Skywork open reaonser series ![GitHub Repo stars](https://img.shields.io/github/stars/SkyworkAI/Skywork-OR1) - [ToRL](https://github.com/GAIR-NLP/ToRL): Scaling tool-integrated RL ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/ToRL) +- [verl-agent](https://github.com/langfengQ/verl-agent): A scalable training framework for **long-horizon LLM/VLM agents**, along with a new algorithm **GiGPO** ![GitHub Repo stars](https://img.shields.io/github/stars/langfengQ/verl-agent) - [GUI-R1](https://github.com/ritzz-ai/GUI-R1): **GUI-R1**: A Generalist R1-style Vision-Language Action Model For **GUI Agents** ![GitHub Repo stars](https://img.shields.io/github/stars/ritzz-ai/GUI-R1) - [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling deep research via reinforcement learning in real-world environments ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher) - [VAGEN](https://github.com/RAGEN-AI/VAGEN): Training VLM agents with multi-turn reinforcement learning ![GitHub Repo stars](https://img.shields.io/github/stars/RAGEN-AI/VAGEN) diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index f10b6521c87..369e59776c7 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -22,5 +22,7 @@ Core APIs .. autoclass:: verl.single_controller.ResourcePool :members: __init__, world_size, local_world_size_list, local_rank_list -.. automodule:: verl.single_controller.ray - :members: RayWorkerGroup, create_colocated_worker_cls \ No newline at end of file +.. autoclass:: verl.single_controller.ray.RayWorkerGroup + :members: __init__ + +.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls \ No newline at end of file diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst index d890b7341c6..cd308c44d09 100644 --- a/docs/api/trainer.rst +++ b/docs/api/trainer.rst @@ -1,5 +1,5 @@ -Trainers -========================= +Trainer Interface +================================ Trainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged. @@ -13,9 +13,16 @@ Core APIs ~~~~~~~~~~~~~~~~~ .. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer + :members: __init__, init_workers, fit + .. automodule:: verl.utils.tokenizer :members: hf_tokenizer -.. automodule:: verl.single_controller - :members: Worker, WorkerGroup, ClassWithInitArgs, ResourcePool + +.. automodule:: verl.trainer.ppo.core_algos + :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty + + +.. automodule:: verl.trainer.ppo.reward + :members: load_reward_manager, compute_reward, compute_reward_async diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 5caf23d1ad6..3ac4380b039 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -1,8 +1,74 @@ -Training utils -========================= +Utilities +============ -Core APIs -~~~~~~~~~~~~~~~~~ +This section documents the utility functions and classes in the VERL library. + +Python Functional Utilities +------------------------------ + +.. automodule:: verl.utils.py_functional + :members: append_to_dict + +File System Utilities +------------------------ + +.. automodule:: verl.utils.fs + :members: copy_to_local + +Tracking Utilities +--------------------- + +.. automodule:: verl.utils.tracking + :members: Tracking + +Metrics Utilities +--------------------- .. automodule:: verl.utils.metric :members: reduce_metrics + +Checkpoint Management +------------------------ + +.. automodule:: verl.utils.checkpoint.checkpoint_manager + :members: find_latest_ckpt_path + +.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager + :members: FSDPCheckpointManager + +Dataset Utilities +--------------------- + +.. automodule:: verl.utils.dataset.rl_dataset + :members: RLHFDataset, collate_fn + +Torch Functional Utilities +----------------------------- + +.. automodule:: verl.utils.torch_functional + :members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits + +Sequence Length Balancing +---------------------------- + +.. automodule:: verl.utils.seqlen_balancing + :members: get_reverse_idx, rearrange_micro_batches + +Ulysses Utilities +-------------------- + +.. automodule:: verl.utils.ulysses + :members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + +FSDP Utilities +------------------ + +.. automodule:: verl.utils.fsdp_utils + :members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, + +Debug Utilities +------------------- + +.. automodule:: verl.utils.debug + :members: log_gpu_memory_usage, GPUMemoryLogger + diff --git a/docs/ascend_tutorial/ascend_quick_start.rst b/docs/ascend_tutorial/ascend_quick_start.rst new file mode 100644 index 00000000000..f65f427ff09 --- /dev/null +++ b/docs/ascend_tutorial/ascend_quick_start.rst @@ -0,0 +1,183 @@ +verl x Ascend +=================================== + + +我们在 verl 上增加对华为昇腾设备的支持。 + +硬件支持 +----------------------------------- + +Atlas 200T A2 Box16 + +Atlas 800T A2 + + +安装 +----------------------------------- + +基础环境准备 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ++-----------+-------------+ +| software | version | ++-----------+-------------+ +| Python | == 3.10 | ++-----------+-------------+ +| CANN | == 8.1.RC1 | ++-----------+-------------+ +| torch | == 2.5.1 | ++-----------+-------------+ +| torch_npu | == 2.5.1.RC1| ++-----------+-------------+ + + +vllm & vllm-ascend +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +为了能够在 verl 中正常使用 vllm,需使用以下命令编译安装 vllm 和 vllm-ascend。请注意根据机器类型区分安装方式。 + +.. code-block:: bash + + # vllm + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git + cd vllm + pip install -r requirements-build.txt + + # for Atlas 200T A2 Box16 + VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ + + # for Atlas 800T A2 + VLLM_TARGET_DEVICE=empty pip install -e . + +.. code-block:: bash + + # vllm-ascend + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + cd vllm-ascend + export COMPILE_CUSTOM_KERNELS=1 + python setup.py install + +安装verl +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + git clone https://github.com/volcengine/verl.git + cd verl + pip install -r requirements-npu.txt + pip install -e . + +其他三方库说明 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ++--------------+---------------+ +| software | description | ++--------------+---------------+ +| transformers | >= v4.52.0 | ++--------------+---------------+ +| flash_attn | not supported | ++--------------+---------------+ +| liger-kernel | not supported | ++--------------+---------------+ + +1. 支持通过 transformers 使能 --flash_attention_2, transformers 需大于等于 4.52.0版本。 +2. 不支持通过 flash_attn 使能 flash attention 加速。 +3. 不支持 liger-kernel 使能。 + + +快速开始 +----------------------------------- +正式使用前,建议您通过对Qwen2.5-0.5B GRPO的训练尝试以检验环境准备和安装的正确性。 + +.. code-block:: bash + + set -x + + export VLLM_ATTENTION_BACKEND=XFORMERS + + python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=512 \ + data.max_response_length=128 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 \ + trainer.device=npu $@ + + +支持现状 +----------------------------------- + ++-----------+----------------------+-------------+-------------------+----------------------+ +| algorithm | model | rewards mae | throughput ratio | hardware | ++-----------+----------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-7B-instruct | 0.38% | 0.588 | Atlas 200T A2 Box16 | ++-----------+----------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-32B-instruct | 0.30% | 0.685 | Atlas 200T A2 Box16 | ++-----------+----------------------+-------------+-------------------+----------------------+ + +目前支持 Qwen2.5 的 GRPO 训练,Qwen2.5-VL GRPO 训练在 vllm-ascend 的修复后支持,涉及到的issue为: + +1. `issues#809 `_ + +2. `issues#825 `_ + + +精度对比说明 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +对于 SFT 类算法,我们期望在相同配置下华为昇腾设备与 A100 的 loss 平均绝对误差<= 2%。计算方式如下图。更多信息请参考 `精度计算说明 `_。 + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/loss_comparison.png?raw=true + :alt: loss_comparison + +根据经验,对于 GRPO 等 RL 类算法,我们期望在相同配置下华为昇腾设备与 A100 的 rewards 平均绝对误差<= 4%,计算方式参考上图。 + + +吞吐对比说明 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Ascend npu 和 A100 分别取日志中前4个 step 的 "perf/throughput" 做平均, throughput ratio = npu 平均值 / A100 平均值。 + + + +计划 +----------------------------------- + +查看 `roadmap `_ 获取更多特性的支持进度。 + + + +声明 +----------------------------------- +verl中提供的ascend支持代码皆为参考样例,商业使用请通过官方正式途径沟通,谢谢。 \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index fe8cf2a5dbf..829a5ed8e71 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,7 +48,11 @@ "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.autosectionlabel", + "sphinx.ext.napoleon", ] +# Use Google style docstrings instead of NumPy docstrings. +napoleon_google_docstring = True +napoleon_numpy_docstring = False # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 0541c3dc17f..2f27e448792 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -97,6 +97,7 @@ Actor/Rollout/Reference Policy moe_config: # Megatron only, can adjust moe configuration freeze_moe_router: False # Megatron only, can freeze moe router (no grad) enable_gradient_checkpointing: False + enable_activation_offload: False trust_remote_code: False use_remove_padding: False actor: @@ -197,6 +198,8 @@ Actor/Rollout/Reference Policy the model's original configurations, mainly dropout - ``actor_rollout_ref.model.enable_gradient_checkpointing``: Whether to enable gradient checkpointing for the actor +- ``actor_rollout_ref.model.enable_activation_offload``: Whether to enable + activation offloading for the actor - ``actor_rollout_ref.model.trust_remote_code``: Whether to enable loading a remote code model @@ -506,6 +509,13 @@ Trainer for the ray register center to be ready. Default is 300 seconds. +This figure illustrates how the configurations affect the training. + +https://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA + +.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d + + evaluation.yaml --------------- diff --git a/docs/faq/faq.rst b/docs/faq/faq.rst index 5cd555fd481..c836b0613fc 100644 --- a/docs/faq/faq.rst +++ b/docs/faq/faq.rst @@ -107,6 +107,8 @@ https://verl.readthedocs.io/en/latest/examples/config.html to disable just-in-ti What is the meaning of train batch size, mini batch size, and micro batch size? ------------------------------------------------------------------------------------------ -Please check out the following figure from the community (credit to @hiyouga) +This figure illustrates the relationship between different batch size configurations. -.. image:: https://github.com/hiyouga/EasyR1/blob/main/assets/easyr1_grpo.png +https://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA + +.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d diff --git a/docs/index.rst b/docs/index.rst index 308051084da..8f9c0adc308 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,8 +108,9 @@ verl is fast with: :caption: API References api/data - api/utils api/single_controller.rst + api/trainer.rst + api/utils.rst .. toctree:: diff --git a/docs/perf/perf_tuning.rst b/docs/perf/perf_tuning.rst index 1b9ae1b0383..bab3dc29dd0 100644 --- a/docs/perf/perf_tuning.rst +++ b/docs/perf/perf_tuning.rst @@ -106,6 +106,9 @@ Therefore, users may need to tune the ``*micro_batch_size_per_gpu`` to accelerat 4. **Allow larger micro-batch sizes for Critic and Reward models**: micro batch size of Critic and Reward model could be larger than Actor model. This is because the actor model has much larger vocab size in the final layer. +5. **Enable activation offloading**: + Set ``actor_rollout_ref.model.enable_activation_offload=True`` and ``critic.model.enable_activation_offload=True``. + This often works together with gradient checkpointing to get larger micro-batch sizes and it's only available in FSDP backend now. Tuning for Dynamic Batch Size ----------------------------- diff --git a/docs/start/multinode.rst b/docs/start/multinode.rst index e278840956b..6caa53c3b29 100644 --- a/docs/start/multinode.rst +++ b/docs/start/multinode.rst @@ -71,6 +71,124 @@ Slurm ----- TBD +dstack +------ +`dstackai/dstack `_ is an open-source container orchestrator that simplifies distributed training across cloud providers and on-premises environments +without the need to use K8S or Slurm. + +Prerequisite +~~~~~~~~~~~~ +Once dstack is `installed `_, initialize the directory as a repo with ``dstack init``. + +.. code-block:: bash + + mkdir myproject && cd myproject + dstack init + +**Create a fleet** + +Before submitting distributed training jobs, create a `dstack` `fleet `_. + +Run a Ray cluster task +~~~~~~~~~~~~~~~~~~~~~~ + +Once the fleet is created, define a Ray cluster task, e.g. in ``ray-cluster.dstack.yml``: + +.. code-block:: yaml + + type: task + name: ray-verl-cluster + + nodes: 2 + + env: + - WANDB_API_KEY + - PYTHONUNBUFFERED=1 + - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2 + commands: + - git clone https://github.com/volcengine/verl + - cd verl + - pip install --no-deps -e . + - pip install hf_transfer hf_xet + - | + if [ $DSTACK_NODE_RANK = 0 ]; then + python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k + python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-7B-Instruct')" + ray start --head --port=6379; + else + ray start --address=$DSTACK_MASTER_NODE_IP:6379 + fi + + # Expose Ray dashboard port + ports: + - 8265 + + resources: + gpu: 80GB:8 + shm_size: 128GB + + # Save checkpoints on the instance + volumes: + - /checkpoints:/checkpoints + +Now, if you run this task via `dstack apply`, it will automatically forward the Ray's dashboard port to `localhost:8265`. + +.. code-block:: bash + + dstack apply -f ray-cluster.dstack.yml + +As long as the `dstack apply` is attached, you can use `localhost:8265` to submit Ray jobs for execution + +Submit Ray jobs +~~~~~~~~~~~~~~~ + +Before you can submit Ray jobs, ensure to install `ray` locally: + +.. code-block:: shell + + pip install ray + +Now you can submit the training job to the Ray cluster which is available at ``localhost:8265``: + +.. code-block:: shell + + $ RAY_ADDRESS=http://localhost:8265 + $ ray job submit \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files=/root/data/gsm8k/train.parquet \ + data.val_files=/root/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-7B-Instruct \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.project_name=ppo_training \ + trainer.experiment_name=qwen-2.5-7B \ + trainer.val_before_train=False \ + trainer.default_hdfs_dir=null \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=2 \ + trainer.default_local_dir=/checkpoints \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 2>&1 | tee verl_demo.log \ + trainer.resume_mode=disable + + +For more details on how `dstack` works, check out its `documentation `_. + How to debug? --------------------- diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh new file mode 100644 index 00000000000..07f5319d62f --- /dev/null +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh @@ -0,0 +1,54 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh new file mode 100644 index 00000000000..ddbf48d5ea7 --- /dev/null +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh @@ -0,0 +1,64 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +FUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=4096 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ + critic.optim.lr=1e-5 \ + critic.model.use_remove_padding=True \ + critic.model.path=Qwen/Qwen2-7B-Instruct \ + critic.model.enable_gradient_checkpointing=True \ + critic.use_dynamic_bsz=True \ + critic.ppo_max_token_len_per_gpu=98304 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + reward_model.enable=True \ + reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ + reward_model.model.use_remove_padding=True \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=32 \ + reward_model.use_dynamic_bsz=True \ + reward_model.forward_max_token_len_per_gpu=98304 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=False \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml new file mode 100644 index 00000000000..b3c5dcb922d --- /dev/null +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml @@ -0,0 +1,24 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang_async + multi_turn: + enable: True + max_turns: 5 + format: qwen + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" + \ No newline at end of file diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh new file mode 100644 index 00000000000..122b424456a --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh @@ -0,0 +1,65 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project +# this is a verification training script, the parallel setting should be tuned to your model + +set -x + +export PYTHONUNBUFFERED=1 +export RAY_DEDUP_LOGS=0 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_megatron_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=/user/longxiang1/models/Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.context_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.megatron.seed=42 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.context_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/train.parquet \ + data.val_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index 9d15c74c681..95a6eb3cebd 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -26,13 +26,14 @@ from tqdm import tqdm from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import ( compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, ) -from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage +from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask class RayDAPOTrainer(RayPPOTrainer): @@ -208,6 +209,10 @@ def fit(self): traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n batch = batch[:traj_bsz] + # === Updating === + + batch.batch["response_mask"] = compute_response_mask(batch) + # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. # Please take care when you implement group based adv computation such as GRPO and rloo @@ -220,6 +225,13 @@ def fit(self): # recompute old_log_probs with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index 27df2cc7438..ccd95d092ef 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -21,6 +21,7 @@ import ray from .dapo_ray_trainer import RayDAPOTrainer +from verl.utils.device import is_cuda_available def get_custom_reward_fn(config): @@ -113,7 +114,6 @@ def run(self, config): role_worker_mapping = { Role.ActorRollout: ray.remote(ActorRolloutRefWorker), Role.Critic: ray.remote(CriticWorker), - Role.RefPolicy: ray.remote(ActorRolloutRefWorker), } global_pool_id = "global_pool" @@ -123,7 +123,6 @@ def run(self, config): mapping = { Role.ActorRollout: global_pool_id, Role.Critic: global_pool_id, - Role.RefPolicy: global_pool_id, } # we should adopt a multi-source reward function here diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh index 824cdad566f..39918ac2d4b 100644 --- a/recipe/dapo/test_dapo_7b_math.sh +++ b/recipe/dapo/test_dapo_7b_math.sh @@ -2,7 +2,7 @@ set -xeuo pipefail project_name='DAPO' -exp_name='DAPO-Qwen2.5-7b-MATH-0519a1' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' adv_estimator=grpo @@ -27,10 +27,11 @@ n_resp_per_prompt=16 train_prompt_mini_bsz=32 # Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} @@ -53,6 +54,8 @@ offload=True gen_tp=4 fsdp_size=32 +# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model + python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ @@ -71,6 +74,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ @@ -113,12 +117,13 @@ python3 -m verl.trainer.main_ppo \ trainer.logger=['console','wandb'] \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ + trainer.val_before_train=True \ trainer.test_freq=10 \ trainer.save_freq=10 \ trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ trainer.default_local_dir="${CKPTS_DIR}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_qwen3_30b_math.sh b/recipe/dapo/test_dapo_qwen3_30b_math.sh new file mode 100644 index 00000000000..56ebd0397ef --- /dev/null +++ b/recipe/dapo/test_dapo_qwen3_30b_math.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=32 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=300 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 12a5d839bf2..23a3c440369 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -32,7 +32,9 @@ reward_model: model: ref_path: ${reward_model.model.path} use_remove_padding: True - use_fused_kernels: False + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fused_kernel_options: + impl_backend: torch # triton, torch tokenizer_path: ${actor_rollout_ref.model.path} enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} ref_type: freeze diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index f03286e6095..cb603b7a3ed 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -47,11 +47,6 @@ def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, rewa self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - if self.use_fused_kernels: - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - self.fused_linear_for_ppo = FusedLinearForPPO() - def _forward_micro_batch(self, micro_batch, prompt_length): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape @@ -85,14 +80,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.use_fused_kernels: - hidden_states = output.last_hidden_state - vocab_weights = self.reward_module.lm_head.weight - - rm_log_labels, _ = self.fused_linear_for_ppo( - hidden_states=hidden_states.squeeze(0), - vocab_weights=vocab_weights, - input_ids=input_ids_rmpad_rolled, - ) + rm_log_labels = output.log_probs.squeeze(0) # (total_nnz,) rm_log_labels = rm_log_labels.to(torch.float32) else: @@ -115,14 +103,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.use_fused_kernels: - hidden_states = output.last_hidden_state - vocab_weights = self.reward_module.lm_head.weight - - rm_log_labels, _ = self.fused_linear_for_ppo.forward( - hidden_states=hidden_states[:, :-1, :], - vocab_weights=vocab_weights, - input_ids=micro_batch["input_ids"][:, 1:], - ) + rm_log_labels = output.log_probs[:, :-1] # (bsz, seq_length) rm_log_labels = rm_log_labels.to(torch.float32) else: @@ -142,18 +123,11 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.use_fused_kernels: - hidden_states = ref_output.last_hidden_state - vocab_weights = self.ref_module.lm_head.weight - - ref_log_labels, _ = self.fused_linear_for_ppo( - hidden_states=hidden_states.squeeze(0), - vocab_weights=vocab_weights, - input_ids=input_ids_rmpad_rolled, - ) + ref_log_labels = ref_output.log_probs.squeeze(0) # (total_nnz,) ref_log_labels = ref_log_labels.to(torch.float32) else: - logits = ref_output.logits.squeeze(0) + ref_output_logits = ref_output.logits.squeeze(0) ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, labels=input_ids_rmpad_rolled) ref_log_labels = gather_outpus_and_unpad(ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) @@ -167,14 +141,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.use_fused_kernels: - hidden_states = ref_output.last_hidden_state - vocab_weights = self.ref_module.lm_head.weight - - ref_log_labels, _ = self.fused_linear_for_ppo.forward( - hidden_states=hidden_states[:, :-1, :], - vocab_weights=vocab_weights, - input_ids=micro_batch["input_ids"][:, 1:], - ) + ref_log_labels = ref_output.log_probs[:, :-1] # (batch_size, seq_length) ref_log_labels = ref_log_labels.to(torch.float32) else: diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 9cce2855731..1b14cfc741e 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -20,6 +20,7 @@ from torch.distributed.device_mesh import init_device_mesh from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer @@ -36,7 +37,6 @@ offload_fsdp_model_to_cpu, offload_fsdp_optimizer, ) -from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.utils.import_utils import import_external_libs from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -129,11 +129,15 @@ def _build_reward_ref_model_optimizer(self, config): trust_remote_code=trust_remote_code, ) + fused_kernel_options = config.model.get("fused_kernel_options", None) + fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + apply_monkey_patch( model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_remove_padding=config.model.get("use_remove_padding", False), use_fused_kernels=config.model.get("use_fused_kernels", False), + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype diff --git a/recipe/spin/README.md b/recipe/spin/README.md index 56a0873f0fa..0fc35ba7b91 100644 --- a/recipe/spin/README.md +++ b/recipe/spin/README.md @@ -1,40 +1,62 @@ -# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models (verl Recipe) +# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models -This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). The implementation uses an **Online Direct Preference Optimization (Online DPO)** approach for language model alignment. This method allows a model to iteratively improve its capabilities by learning from preferences generated using its own outputs, potentially reducing reliance on external preference datasets or stronger teacher models. +This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory. -Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) +**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models: -verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20) +1. **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations. +2. **Two-Player Game Setup:** A game involving two players acted by a single LLM. +3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration. + +Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) [[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)] -## Algorithm: Online DPO Inspired by SPIN +verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20) -This recipe implements an Online DPO algorithm adapted to the `verl` Reinforcement Learning framework, drawing inspiration from concepts presented in SPIN. It provides an alternative to PPO for fine-tuning language models. +--- -**Core Idea:** Instead of maximizing a scalar reward signal, this approach directly optimizes the policy model to align with preference data generated *online* during training: +## Key Function (compute_online_dpo_loss) and Related works +SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). -1. **Generation:** The current policy model (actor) generates two (or more) responses for each prompt in a batch. -2. **Preference Labeling:** A reward model or reward function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). -3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using the DPO loss function, comparing against a reference model. +This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data. -**Connection to SPIN:** -While this recipe uses the DPO loss, the online generation loop where the current model generates data used for its own update shares conceptual similarities with the self-play idea in SPIN. The periodic update of the reference model (potentially using weights from the actor) further aligns with SPIN's iterative self-improvement concepts. +Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets. **Reference Papers:** -* **SPIN:** [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) -* **DPO:** [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) +* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) +* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) +* [Somethings are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) +* [Iterative preference learning from human feedback: Bridging theory and practice for rlhf under kl-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023) +* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024) +* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024) + + +## Our Online DPO Implementation + +Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include: -## Implementation within verl -The recipe is expected to be working on verl v0.3.0.post1 +* **No Critic:** Unlike PPO, we omit the value function critic. +* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline. +* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems). +* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences. +* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles. -This implementation adapts the existing PPO infrastructure provided by `verl`: +--- +## Algorithm -* **No Critic:** The value function critic model used in PPO is not required and is omitted. -* **Reference Model:** An explicit reference policy model (`ref_policy_wg`) is maintained and used in the DPO loss calculation. This implementation allows for periodically updating the reference model's weights from the actor model (controlled by `ref_update_freq`). -* **Preference Calculation:** Logic (`compute_onlineDPO_pref` in `core_algos.py`) determines chosen/rejected pairs based on scores from a reward source. -* **DPO Loss:** The PPO policy loss and advantage calculations are replaced with the DPO loss computation (`compute_online_dpo_loss` in `core_algos.py`) within the actor update step (`dp_actor.py`). -* **Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the training loop: generation, preference labeling, optional reference model updates, and policy updates via the DPO loss. +This recipe implements an Online algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models. + +**Online Loop:** Instead of maximizing a scalar reward signal in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training: + +1. **Generation:** The current model generates multiple responses for each prompt in a batch. +2. **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problem). +3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model. + +**Connection with SPIN:** +Instead of only using a fixed target data distribution, the online generation loop in step 2 will dynamically change the target data distribution by using a certain Preference Labeling method (rule-based ranking on the math problem by selecting the better one in this recipe). This explores the direction mentioned in SPIN's paper Section 7 about "dynamically changing target data distribution" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling. + +--- ## Reproduce the Experiment (Example Setup) @@ -73,7 +95,7 @@ The following steps outline how to set up the environment and run the SPIN recip ```bash # Clone the verl repository and checkout the spin branch cd ~ - git clone git@github.com:volcengine/verl.git](git@github.com:volcengine/verl.git) && cd verl + git clone git@github.com:volcengine/verl.git && cd verl # Install flash-attn (handle potential build issues) python3 -m uv pip install wheel packaging @@ -111,6 +133,8 @@ The following steps outline how to set up the environment and run the SPIN recip bash recipe/spin/run_spin.sh ``` +--- + ## Configuration * The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`). @@ -121,10 +145,12 @@ The following steps outline how to set up the environment and run the SPIN recip * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`. * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor). +--- + ## Key Files -* `main_spin.py`: Main entry point using Hydra to load config and launch the `SpinTrainer`. -* `spin_trainer.py`: Defines the `SpinTrainer` class orchestrating the Online DPO training loop. +* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`. +* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop. * `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP. * `dp_actor.py`: Contains the actor class, including the DPO policy update logic. * `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`. @@ -132,17 +158,22 @@ The following steps outline how to set up the environment and run the SPIN recip * `run_spin.sh` (or similar): Example bash script for launching a training run. * `README.md`: This file. +--- + ## Acknowledgement We sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO): -- [Yue Wu](https://yuewu.us/) -- [Yuhao Yang](https://github.com/yhyang201) -- [Yifan Zhang](https://github.com/yifanzhang-pro) -- [Yongan Xiang](https://github.com/BearBiscuit05) -- [Junrong Lin](https://github.com/ocss884) -- [Yuxuan Tong](https://github.com/tongyx361) -- [Guangming Shen](https://github.com/PeterSH6) -- [Biao He](https://www.linkedin.com/in/biao-he/) -- [Qingquan Song](https://qingquansong.github.io/) -- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) +* [Zixiang Chen](https://sites.google.com/view/zxchen) +* [Yuhao Yang](https://github.com/yhyang201) +* [Yifan Zhang](https://github.com/yifanzhang-pro) +* [Yongan Xiang](https://github.com/BearBiscuit05) +* [Junrong Lin](https://github.com/ocss884) +* [Yuxuan Tong](https://github.com/tongyx361) +* [Guangming Shen](https://github.com/PeterSH6) +* [Biao He](https://www.linkedin.com/in/biao-he/) +* [Qingquan Song](https://qingquansong.github.io/) +* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/) +* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) + +--- diff --git a/recipe/sppo/main_sppo.py b/recipe/sppo/main_sppo.py index eae1e43e343..25b1c469e7c 100644 --- a/recipe/sppo/main_sppo.py +++ b/recipe/sppo/main_sppo.py @@ -25,6 +25,7 @@ from verl.trainer.ppo.reward import load_reward_manager from .sppo_ray_trainer import RaySPPOTrainer +from verl.utils.device import is_cuda_available @hydra.main(config_path="config", config_name="sppo_trainer", version_base=None) @@ -140,6 +141,7 @@ def run(self, config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name="cuda" if is_cuda_available else "npu", ) trainer.init_workers() trainer.fit() diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py index 0e870a0facd..761def940bc 100644 --- a/recipe/sppo/sppo_ray_trainer.py +++ b/recipe/sppo/sppo_ray_trainer.py @@ -86,6 +86,7 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + device_name="cuda", ): self.tokenizer = tokenizer self.processor = processor @@ -105,6 +106,7 @@ def __init__( self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls self.validation_generations_logger = ValidationGenerationsLogger() + self.device_name = device_name # define in-reward KL control # kl loss control currently not suppoorted diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 00000000000..601e8f9fa6e --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,20 @@ +# requirements.txt records the full set of dependencies for development +accelerate +codetiming +datasets +dill +hydra-core +numpy +pandas +peft +pyarrow>=15.0.0 +pybind11 +pylatexenc +ray +tensordict<=0.6.2 +transformers>=4.52.0 +wandb +mathruler +torchdata +einops +qwen_vl_utils diff --git a/scripts/model_merger.py b/scripts/model_merger.py index aa0c2e5d292..3bd25cae2ff 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -75,6 +75,7 @@ class ModelMergerConfig: operation: str # 'merge' or 'test' backend: str local_dir: str + hf_model_config_path: str target_dir: Optional[str] = "tmp" hf_upload_path: Optional[str] = None private: bool = False @@ -95,13 +96,13 @@ def __post_init__(self): class BaseModelMerger(ABC): def __init__(self, config: ModelMergerConfig): self.config = config - self.config_path = config.local_dir + self.hf_model_config_path = config.hf_model_config_path if config.hf_model_path: print("Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. ") - self.config_path = config.hf_model_path + self.hf_model_config_path = config.hf_model_path - self.model_config = AutoConfig.from_pretrained(self.config_path) + self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) def get_transformers_auto_model_class(self): if "ForTokenClassification" in self.model_config.architectures[0]: @@ -122,9 +123,9 @@ def patch_model_generation_config(self, model): """ if model.can_generate(): try: - model.generation_config = GenerationConfig.from_pretrained(self.config_path) + model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) except OSError: - print(f"Warning: Generation config file not found in {self.config_path}, using a generation config created from the model config.") + print(f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.") return model def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): @@ -139,8 +140,8 @@ def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): del state_dict del model - processor = hf_processor(self.config_path) - tokenizer = hf_tokenizer(self.config_path) + processor = hf_processor(self.hf_model_config_path) + tokenizer = hf_tokenizer(self.hf_model_config_path) if processor is not None: print(f"Saving processor to {self.config.target_dir}") processor.save_pretrained(self.config.target_dir) @@ -332,6 +333,12 @@ def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): class MegatronModelMerger(BaseModelMerger): + def __init__(self, config: ModelMergerConfig): + from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path + + config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) + super().__init__(config) + def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) assert match, f"Invalid sharded dir {sharded_dir}" @@ -578,6 +585,7 @@ def main(): "is_value_model": args.is_value_model, "local_dir": args.local_dir, "hf_model_path": args.hf_model_path, + "hf_model_config_path": args.local_dir, } if args.operation == "merge": diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index ef1d7c51780..c4f64870356 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -18,7 +18,8 @@ ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} RM_PAD=${RM_PAD:-True} -FUSED_KERNELS=${FUSED_KERNELS:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} USE_KL=${USE_KL:-False} CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False} @@ -29,7 +30,7 @@ TEST_FREQ=${TEST_FREQ:--1} # Save & Resume RESUME_MODE=${RESUME_MODE:-disable} SAVE_FREQ=${SAVE_FREQ:--1} -TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} # whether to save hf_model SAVE_HF_MODEL=${SAVE_HF_MODEL:-False} @@ -78,6 +79,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \ @@ -115,7 +117,7 @@ python3 -m verl.trainer.main_ppo \ trainer.save_freq="${SAVE_FREQ}" \ trainer.resume_mode="${RESUME_MODE}" \ trainer.total_epochs=2 \ - trainer.total_training_steps="${TOT_TRAIN_STEPS}" $@ \ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \ | tee "${output_file}" if [ "${CUSTOM_REWARD_FN}" = "True" ]; then diff --git a/tests/e2e/ppo_trainer/run_model_reward.sh b/tests/e2e/ppo_trainer/run_model_reward.sh index 19b7c8d1cf9..5e401ad0087 100644 --- a/tests/e2e/ppo_trainer/run_model_reward.sh +++ b/tests/e2e/ppo_trainer/run_model_reward.sh @@ -11,6 +11,8 @@ TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} RM_PAD=${RM_PAD:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend SP_SIZE=${SP_SIZE:-1} SEQ_BALANCE=${SEQ_BALANCE:-False} LIGER=${LIGER:-False} @@ -20,7 +22,7 @@ TEST_FREQ=${TEST_FREQ:--1} # Save & Resume RESUME_MODE=${RESUME_MODE:-disable} SAVE_FREQ=${SAVE_FREQ:--1} -TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} train_traj_micro_bsz_per_gpu=2 # b n_resp_per_prompt=4 # g @@ -47,6 +49,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_liger="${LIGER}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ + actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \ @@ -94,4 +98,4 @@ python3 -m verl.trainer.main_ppo \ trainer.save_freq="${SAVE_FREQ}" \ trainer.resume_mode="${RESUME_MODE}" \ trainer.total_epochs=2 \ - trainer.total_training_steps="${TOT_TRAIN_STEPS}" $@ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ diff --git a/tests/e2e/run_dapo.sh b/tests/e2e/run_dapo.sh index ef748dd92fb..bdbc40b12c7 100644 --- a/tests/e2e/run_dapo.sh +++ b/tests/e2e/run_dapo.sh @@ -66,6 +66,7 @@ python3 -m recipe.dapo.main_dapo \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ diff --git a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh index 364a03723ae..2797b2cf5c5 100644 --- a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh +++ b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh @@ -9,6 +9,7 @@ ulimit -n 65535 PROJECT_DIR="$(pwd)" CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" +FSDP_STRATEGY=${FSDP_STRATEGY:-fsdp} python3 -m verl.trainer.main_ppo \ --config-path="$CONFIG_PATH" \ @@ -30,6 +31,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ @@ -38,12 +40,13 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger=['console'] \ trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-rebased-0427-verify-n16' \ + trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh index a70db50ad99..691a9f188de 100644 --- a/tests/e2e/run_ppo_trainer_megatron.sh +++ b/tests/e2e/run_ppo_trainer_megatron.sh @@ -19,8 +19,11 @@ TEST_FREQ=${TEST_FREQ:--1} # Save & Resume RESUME_MODE=${RESUME_MODE:-disable} SAVE_FREQ=${SAVE_FREQ:--1} -TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} +USE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True} +ppo_max_token_len_per_gpu=2400 +forward_max_token_len_per_gpu=4800 train_traj_micro_bsz_per_gpu=2 # b n_resp_per_prompt=4 # g @@ -55,74 +58,95 @@ RM_VPP=${RM_VPP:-$COMMON_VPP} RM_CP=${RM_CP:-$COMMON_CP} RM_TP=${RM_TP:-$TRAIN_TP} +ALL_OFFLOAD=${ALL_OFFLOAD:-False} +COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} +COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} +COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} + +ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} + CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra'] SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0} if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then CHECKPOINT_CONTENTS=['model','optimizer','extra'] fi +ENGINES=("vllm" "sglang_async") + exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator="${ADV_ESTIMATOR}" \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size=${train_prompt_bsz} \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ - actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ - actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.optim.lr=2e-5 \ - critic.model.path="${MODEL_PATH}" \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ - critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ - critic.megatron.context_parallel_size=$CRITIC_CP \ - critic.megatron.tensor_model_parallel_size=$CRITIC_TP \ - critic.checkpoint.contents=$CHECKPOINT_CONTENTS \ - reward_model.enable=True \ - reward_model.model.path="${MODEL_PATH}" \ - reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ - reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ - reward_model.megatron.context_parallel_size=$RM_CP \ - reward_model.megatron.tensor_model_parallel_size=$RM_TP \ - algorithm.use_kl_in_reward=False \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ - trainer.test_freq="${TEST_FREQ}" \ - trainer.save_freq="${SAVE_FREQ}" \ - trainer.resume_mode="${RESUME_MODE}" \ - trainer.total_epochs=2 \ - trainer.total_training_steps="${TOT_TRAIN_STEPS}" $@ \ No newline at end of file +for ENGINE in "${ENGINES[@]}"; do + python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator="${ADV_ESTIMATOR}" \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.train_batch_size=${train_prompt_bsz} \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ + actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \ + actor_rollout_ref.rollout.name="${ENGINE}" \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ + actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.optim.lr=2e-5 \ + critic.model.path="${MODEL_PATH}" \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \ + critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ + critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ + critic.megatron.context_parallel_size=$CRITIC_CP \ + critic.megatron.tensor_model_parallel_size=$CRITIC_TP \ + critic.checkpoint.contents=$CHECKPOINT_CONTENTS \ + reward_model.enable=True \ + reward_model.model.path="${MODEL_PATH}" \ + reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ + reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ + reward_model.megatron.context_parallel_size=$RM_CP \ + reward_model.megatron.tensor_model_parallel_size=$RM_TP \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl-test' \ + trainer.experiment_name="${exp_name}" \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=${NUM_GPUS} \ + trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ + trainer.test_freq="${TEST_FREQ}" \ + trainer.save_freq="${SAVE_FREQ}" \ + trainer.resume_mode="${RESUME_MODE}" \ + trainer.total_epochs=2 \ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ +done diff --git a/tests/e2e/run_prime.sh b/tests/e2e/run_prime.sh index da7664af320..0d0a8b50a8b 100644 --- a/tests/e2e/run_prime.sh +++ b/tests/e2e/run_prime.sh @@ -34,7 +34,7 @@ python3 -m recipe.prime.main_prime \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=5e-7 \ actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ diff --git a/tests/e2e/run_ray_trainer.sh b/tests/e2e/run_ray_trainer.sh index d6c8451b64c..f9cb19aeb2b 100644 --- a/tests/e2e/run_ray_trainer.sh +++ b/tests/e2e/run_ray_trainer.sh @@ -17,6 +17,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ data.return_raw_input_ids=True \ actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=128 \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.actor.optim.lr=1e-4 \ diff --git a/tests/e2e/run_ray_trainer_rmpad.sh b/tests/e2e/run_ray_trainer_rmpad.sh index e4ca687d024..edab167e652 100644 --- a/tests/e2e/run_ray_trainer_rmpad.sh +++ b/tests/e2e/run_ray_trainer_rmpad.sh @@ -8,6 +8,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ algorithm.adv_estimator=gae \ data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ actor_rollout_ref.rollout.name=vllm \ @@ -16,4 +17,4 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.use_remove_padding=True \ algorithm.use_kl_in_reward=False \ - trainer.total_epochs=1 \ No newline at end of file + trainer.total_epochs=1 diff --git a/tests/e2e/run_sppo.sh b/tests/e2e/run_sppo.sh index 54b6d4c99af..1fa8895a8e9 100644 --- a/tests/e2e/run_sppo.sh +++ b/tests/e2e/run_sppo.sh @@ -24,6 +24,7 @@ python3 -m recipe.sppo.main_sppo \ actor_rollout_ref.model.path="./models/Qwen2.5-0.5B-Instruct" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ @@ -42,4 +43,4 @@ python3 -m recipe.sppo.main_sppo \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ - trainer.total_epochs=2 $@ \ No newline at end of file + trainer.total_epochs=2 $@ diff --git a/tests/kernels/test_linear_cross_entropy.py b/tests/kernels/test_linear_cross_entropy.py index f0fd0e1a63d..cfae0da568d 100644 --- a/tests/kernels/test_linear_cross_entropy.py +++ b/tests/kernels/test_linear_cross_entropy.py @@ -29,12 +29,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import typing import torch import verl.utils.torch_functional as verl_F from verl.utils.experimental.torch_functional import FusedLinearForPPO +from verl.utils.kernel import linear_cross_entropy from verl.utils.torch_functional import logprobs_from_logits compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) @@ -42,10 +44,11 @@ fused_linear_for_ppo.compile(dynamic=True) -def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction="none") -> typing.List[torch.Tensor]: +def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none") -> typing.List[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] @@ -55,10 +58,16 @@ def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch. return logprobs, entropy -def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]: +def run_verl_original_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +) -> typing.List[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature # compute entropy entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) @@ -67,23 +76,27 @@ def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels # To be tested -def run_verl_torch_fused_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor): +def run_verl_torch_fused_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +): hidden = hidden.to(torch.float32) weight = weight.to(torch.float32) logprobs, entropy = fused_linear_for_ppo( hidden, weight, labels, + temperature=temperature, ) return logprobs.squeeze(0), entropy.squeeze(0) -MAX_TEST_CASES = 5 - - class TestLinearCrossEntropy: - def __init__(self, test_case_idx: int) -> None: + def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None: self.test_case_idx = test_case_idx + self.temperature = temperature def cleanup(self): torch.cuda.empty_cache() @@ -121,7 +134,7 @@ def generate_hyper(self): self.hidden_size = 4096 self.vocab_size = 102400 else: - raise ValueError(f"Invalid test case index: {test_case_idx}") + raise ValueError(f"Invalid test case index: {self.test_case_idx}") def generate_forward_inputs(self): hidden = torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() @@ -144,6 +157,8 @@ def verify_correctness(self, iterations=5): verl_backward_latency = list() verl_fused_forward_latency = list() verl_fused_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -153,30 +168,44 @@ def verify_correctness(self, iterations=5): hidden, weight, labels = self.generate_forward_inputs() start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels) + (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() verl_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(hidden, weight, labels) + (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() verl_fused_forward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + # backward g_entropy, g_logprobs = self.generate_backward_inputs() @@ -198,12 +227,28 @@ def verify_correctness(self, iterations=5): torch.cuda.synchronize() verl_fused_backward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) # remove first latency torch_forward_latency = torch_forward_latency[1:] @@ -212,6 +257,8 @@ def verify_correctness(self, iterations=5): verl_backward_latency = verl_backward_latency[1:] verl_fused_forward_latency = verl_fused_forward_latency[1:] verl_fused_backward_latency = verl_fused_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] print("\n[INFO]: Verified forward & backward correctness.") @@ -221,6 +268,8 @@ def verify_correctness(self, iterations=5): print(f"[INFO]: Backward pass: VeRL implementation average time: {sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms") print(f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms") print(f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: {sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: {sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") def check_storage(self, method_name, run_forward): self.cleanup() @@ -229,7 +278,7 @@ def check_storage(self, method_name, run_forward): hidden, weight, labels = self.generate_forward_inputs() torch.cuda.reset_peak_memory_stats() - (logprobs, entropy) = run_forward(hidden, weight, labels) + (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature) torch.cuda.synchronize() torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") @@ -246,7 +295,10 @@ def check_storage_all(self): self.check_storage("Torch", run_torch_entropy) self.check_storage("VeRL", run_verl_original_entropy) self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy) + self.check_storage("Kernel", linear_cross_entropy) + +MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) if __name__ == "__main__": # torch.cuda.memory._record_memory_history() diff --git a/tests/kernels/test_linear_cross_entropy_tp.py b/tests/kernels/test_linear_cross_entropy_tp.py new file mode 100644 index 00000000000..dfc84214a22 --- /dev/null +++ b/tests/kernels/test_linear_cross_entropy_tp.py @@ -0,0 +1,439 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import typing + +import torch +import torch.distributed as dist + +try: + from verl.utils.kernel import linear_cross_entropy +except ImportError: + # FIXME: remove these manually included paths + import sys + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) +finally: + from verl.utils.kernel import linear_cross_entropy + +import verl.utils.torch_functional as verl_F + +compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + + +def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none") -> typing.List[torch.Tensor]: + # [num_tokens, vocab_size] + if len(hidden.shape) > 2: + hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] + if len(labels.shape) > 1: + labels = labels.view(-1) + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32)) + logits /= temperature + pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] + entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens] + logprobs = torch.neg(logprobs) + return logprobs, entropy + + +class TorchEntropyTP(torch.autograd.Function): + """ + it is used for testing the correctness of the kernel + it is not efficient and is not recommended to use in practice + """ + + @staticmethod + def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, dist_process_group: torch.distributed.ProcessGroup): + # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size] + ctx.original_hidden_shape = hidden.shape + if len(hidden.shape) > 2: + hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] + if len(labels.shape) > 1: + labels = labels.view(-1) + + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size] + logits /= temperature + whole_logits = torch.empty((logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), dtype=logits.dtype, device=logits.device) + whole_logits_ref = [whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]] for i in range(dist.get_world_size(dist_process_group))] + dist.all_gather(whole_logits_ref, logits, group=dist_process_group) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + + logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.neg(logprobs) + + ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) + ctx.dist_process_group = dist_process_group + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): + hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors + dist_process_group = ctx.dist_process_group + temperature = ctx.temperature + batch_size, hidden_size = hidden.shape + vocab_size, hidden_size = weight.shape + rank = dist.get_rank(dist_process_group) + + # Compute softmax probabilities + maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) + exp_logits = torch.exp(whole_logits - maximum) + accumulate = exp_logits.sum(dim=-1, keepdim=True) + pd = exp_logits / accumulate + + # Gradient for entropy + # entropy = entropy_a - entropy_b + # entropy_a = log(sum(exp(logits))) + # entropy_b = sum(pd * logits) + # d_entropy_a/d_logits = pd + # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = d_entropy_a - d_entropy_b + # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) + d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + + # Gradient for logprobs + # logprobs = -cross_entropy = -log(pd[labels]) + # d_logprobs/d_logits = (pd - one_hot(labels)) + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + g_logprobs = torch.neg(g_logprobs) + d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) + # NOTE: This will lead to wrong result + # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot + + # Combine gradients + d_logits = d_logits_entropy + d_logits_logprobs + d_logits /= temperature + + # Get local slice of gradients + local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size] + + # Compute gradients for hidden and weight + d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32)) + d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32)) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return d_hidden, d_weight, None, None, None + + +run_torch_entropy_tp = TorchEntropyTP.apply + +MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) + + +class TestLinearCrossEntropy_TensorParallel: + def __init__(self): + dist.init_process_group(backend="nccl") + self.group = dist.group.WORLD + + self.local_rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") + + def initialize(self, test_case_idx: int, temperature: float = 1.5): + self.test_case_idx = test_case_idx + self.temperature = temperature + + def shutdown(self): + dist.destroy_process_group() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.dtype = torch.bfloat16 + if self.test_case_idx == 0: + self.batch_size = 1 + self.num_tokens = 1937 + self.hidden_size = 3584 + self.vocab_size = 152064 + elif self.test_case_idx == 1: + self.batch_size = 1 + self.num_tokens = 2169 + self.hidden_size = 896 + self.vocab_size = 151936 + elif self.test_case_idx == 2: + self.batch_size = 1 + self.num_tokens = 1530 + self.hidden_size = 2048 + self.vocab_size = 32256 + elif self.test_case_idx == 3: + self.batch_size = 1 + self.num_tokens = 1388 + self.hidden_size = 4096 + self.vocab_size = 102400 + elif self.test_case_idx == 4: + self.batch_size = 1 + self.num_tokens = 8192 + self.hidden_size = 4096 + self.vocab_size = 102400 + else: + raise ValueError(f"Invalid test case index: {self.test_case_idx}") + + def generate_forward_inputs(self): + hidden = torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() + weight = torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() + labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) + g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) + return g_entropy, g_logprobs + + def verify_torch_itself(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + # forward pass + # Create a tensor to hold the gathered weights from all ranks + # weight has shape [vocab_size, hidden_size] + # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size] + + # Create a single contiguous tensor to hold all gathered weights + whole_weight = torch.empty((self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device) + + # Create views into the tensor for each rank's portion + whole_weight_views = [whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)] + + # Perform all_gather operation using the views + dist.all_gather(whole_weight_views, weight, group=self.group) + + # Set requires_grad for autograd + whole_weight.requires_grad_() + + (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature) + + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + + torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + (single_d_hidden, single_d_weight) = torch.autograd.grad((single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False) + + (tp_d_hidden, tp_d_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + # Extract the corresponding slice from single_d_weight for comparison + # tp_d_weight has shape [vocab_size, hidden_size] + # single_d_weight has shape [vocab_size * world_size, hidden_size] + torch.testing.assert_close(tp_d_weight, single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size], atol=1e-2, rtol=1e-4) + + # atol=1e-3, rtol=1e-4) + if self.local_rank == 0: + print("[PASS] torch TP correctness is verified") + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + torch.cuda.synchronize() + forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_tp_hidden, d_tp_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + torch.cuda.synchronize() + backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") + print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + + def verify_kernel_correctness(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature, "none", self.group) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + start_event.record() + (torch_d_hidden, torch_d_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + start_event.record() + (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + if self.local_rank == 0: + print("\n[PASS]: Verified kernel forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: {sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: {sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: {sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: {sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature, "none", self.group) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + +if __name__ == "__main__": + # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py + + # Check if running with torchrun (distributed mode) + assert int(os.environ["WORLD_SIZE"]) > 1, "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to execute this script." + torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) + + # set_backward_method(BackwardEnum._Total_Fuse_MN) + # set_backward_method(BackwardEnum._Split_Dlogits_N) + + test = TestLinearCrossEntropy_TensorParallel() + for test_case_idx in range(MAX_TEST_CASES): + print(f"[INFO] Running test case {test_case_idx}") + test.initialize(test_case_idx) + test.verify_torch_itself() + test.check_torch_storage() + test.verify_kernel_correctness() + test.check_kernel_storage() + + test.shutdown() diff --git a/tests/npu/run_qwen2_5_05b_grpo.sh b/tests/npu/run_qwen2_5_05b_grpo.sh new file mode 100644 index 00000000000..d54102b7506 --- /dev/null +++ b/tests/npu/run_qwen2_5_05b_grpo.sh @@ -0,0 +1,44 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=512 \ + data.max_response_length=128 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 \ + trainer.total_training_steps=2 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_32b_grpo.sh b/tests/npu/run_qwen2_5_32b_grpo.sh new file mode 100644 index 00000000000..461b27b80fd --- /dev/null +++ b/tests/npu/run_qwen2_5_32b_grpo.sh @@ -0,0 +1,44 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6\ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_32b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_7b_grpo.sh b/tests/npu/run_qwen2_5_7b_grpo.sh new file mode 100644 index 00000000000..ff173e2b5f6 --- /dev/null +++ b/tests/npu/run_qwen2_5_7b_grpo.sh @@ -0,0 +1,45 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/ray_cpu/test_ray_utils.py b/tests/ray_cpu/test_ray_utils.py new file mode 100644 index 00000000000..e36497d210f --- /dev/null +++ b/tests/ray_cpu/test_ray_utils.py @@ -0,0 +1,54 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import ray + +from verl.utils.ray_utils import parallel_put + + +# Initialize Ray for testing if not already done globally +@pytest.fixture() +def init_ray(): + ray.init(num_cpus=4) + yield + ray.shutdown() + + +def test_parallel_put_basic(init_ray): + data = [1, "hello", {"a": 2}, [3, 4]] + refs = parallel_put(data) + assert len(refs) == len(data) + retrieved_data = [ray.get(ref) for ref in refs] + assert retrieved_data == data + + +def test_parallel_put_empty(init_ray): + data = [] + with pytest.raises(AssertionError): + _ = parallel_put(data) + + +def test_parallel_put_workers(init_ray): + data = list(range(20)) + # Test with specific number of workers + refs = parallel_put(data, max_workers=4) + assert len(refs) == len(data) + retrieved_data = [ray.get(ref) for ref in refs] + assert retrieved_data == data + # Test with default workers (should cap) + refs_default = parallel_put(data) + assert len(refs_default) == len(data) + retrieved_data_default = [ray.get(ref) for ref in refs_default] + assert retrieved_data_default == data diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py index 12a1048d184..e3e0b10dba6 100644 --- a/tests/sandbox/test_sandbox.py +++ b/tests/sandbox/test_sandbox.py @@ -18,7 +18,7 @@ import pytest -from verl.utils.reward_score import _default_compute_score, prime_code, sandbox_fusion +from verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion from verl.utils.reward_score.prime_code import apps_check_correctness from verl.workers.reward_manager.prime import parallel_compute_score_async @@ -109,7 +109,7 @@ def test_parallelism(): ground_truth.extend(prime_math_gts) data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) - scores = asyncio.run(parallel_compute_score_async(_default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) + scores = asyncio.run(parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) print(scores) @@ -119,7 +119,7 @@ def test_prime_code(): """ data_source = "codecontests" for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ @@ -135,7 +135,7 @@ def test_prime_code_sandbox_fusion(): # Removed the previous 'if not sandbox_url' check block for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = _default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable + score = default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable assert float(score) == score_ @@ -153,7 +153,7 @@ def test_continuous_score_consistency(): prime_score, _ = sandbox_fusion.compute_score(os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True) # 2. Calculate score using sandbox_fusion with continuous=True - # Ensure the extra_info key triggers the sandbox_fusion path in _default_compute_score + # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True) # 3. Assert scores are equal (using pytest.approx for float comparison) @@ -175,5 +175,5 @@ def test_check_correctness(): def test_prime_math(): data_source = "numina_aops_forum" for completion, ground_truth in zip(prime_math_answers, prime_math_gts): - score = _default_compute_score(data_source, completion, ground_truth) + score = default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py b/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py new file mode 100644 index 00000000000..cf442a03b58 --- /dev/null +++ b/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py @@ -0,0 +1,47 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.utils.megatron.pipeline_parallel import make_batch_generator + + +def test_make_batch_generator_no_vpp(): + batches = [1, 2, 3] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == batches + + +def test_make_batch_generator_with_vpp(): + batches = [{"data": 1}, {"data": 2}] + vpp_size = 2 + generators = make_batch_generator(batches, vpp_size) + assert isinstance(generators, list) + assert len(generators) == vpp_size + + # Check each generator yields the original batches + for gen in generators: + assert list(gen) == batches + + +def test_make_batch_generator_empty(): + batches = [] + vpp_size = 1 + generator = make_batch_generator(batches, vpp_size) + assert list(generator) == [] + + vpp_size = 3 + generators = make_batch_generator(batches, vpp_size) + assert len(generators) == vpp_size + for gen in generators: + assert list(gen) == [] diff --git a/tests/utils/gpu_tests/test_activation_offload.py b/tests/utils/gpu_tests/test_activation_offload.py new file mode 100644 index 00000000000..c4669063033 --- /dev/null +++ b/tests/utils/gpu_tests/test_activation_offload.py @@ -0,0 +1,143 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import tempfile + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +from torch.distributed import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config + +from verl.utils.activation_offload import enable_activation_offloading +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy + + +def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"): + torch.cuda.set_device(rank) + torch.distributed.init_process_group( + backend="nccl", + init_method=f"file://{rendezvous_file}", + rank=rank, + world_size=world_size, + ) + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) + + model_name = "Qwen/Qwen2.5-0.5B-Instruct" + config = Qwen2Config(num_hidden_layers=4) + + with torch.device("cuda"): + model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = model.to(device="cuda") + + # Wrap model with FSDP + mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + + if strategy == "fsdp": + model = FSDP(model, use_orig_params=False, device_id=torch.cuda.current_device(), sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision, device_mesh=device_mesh, auto_wrap_policy=get_fsdp_wrap_policy(module=model)) + else: + mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) + fsdp_kwargs = { + "mesh": device_mesh, + "mp_policy": mp_policy, + } + apply_fsdp2(model, fsdp_kwargs, {}) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) + + # Create checkpoint manager + tokenizer = AutoTokenizer.from_pretrained(model_name) + checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer) + + # Generate sample input + batch_size = 2 + seq_len = 32 + vocab_size = 32000 + # First input for initial update + input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + attention_mask1 = torch.ones_like(input_ids1) + + # Second input for verification + input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + attention_mask2 = torch.ones_like(input_ids2) + + # Step 1: Initial update and save checkpoint + outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1) + loss1 = outputs1.logits.mean() + loss1.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Save checkpoint after first update + temp_dir = tempfile.mkdtemp() + checkpoint_path = os.path.join(temp_dir, "checkpoint") + checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) + + # Step 2: Second update and forward pass + outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) + loss2 = outputs2.logits.mean() + loss2.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Record logits after second update + with torch.no_grad(): + logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + + # Step 3: wrap module with activation offloading and load checkpoint + enable_activation_offloading(model, "fsdp") + checkpoint_manager.load_checkpoint(checkpoint_path) + + # Step 4: Repeat the second update with same input + outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2) + loss3 = outputs3.logits.mean() + loss3.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Record logits after loaded checkpoint and update + with torch.no_grad(): + logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + + # Step 4: Verify outputs match + torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0) + print(f"Activaiton offloading for {strategy} test passed on {world_size} GPUs!") + + # Cleanup + shutil.rmtree(temp_dir) + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +@pytest.mark.parametrize("world_size", (2, 4)) +@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) +def test_activation_offloading(world_size, strategy, tmp_path): + rendezvous_file = str(tmp_path / "rdzv_file") + os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) + + mp.spawn( + fn=_fsdp_activation_offloading_test, + args=(world_size, rendezvous_file, strategy), + nprocs=world_size, + join=True, + ) diff --git a/verl/__init__.py b/verl/__init__.py index 9f4fa70d77f..d1b8547bca3 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -14,9 +14,13 @@ import logging import os +import pkg_resources +from pkg_resources import DistributionNotFound +from packaging.version import parse as parse_version from .protocol import DataProto from .utils.logging_utils import set_basic_config +from .utils.device import is_npu_available version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) @@ -38,3 +42,17 @@ from modelscope.utils.hf_util import patch_hub patch_hub() + +if is_npu_available: + package_name = 'transformers' + required_version_spec = '4.51.0' + try: + installed_version = pkg_resources.get_distribution(package_name).version + installed = parse_version(installed_version) + required = parse_version(required_version_spec) + + if not installed >= required: + raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.") + except DistributionNotFound: + raise ImportError( + f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}") diff --git a/verl/models/transformers/dense_common.py b/verl/models/transformers/dense_common.py new file mode 100644 index 00000000000..ba31d883c3d --- /dev/null +++ b/verl/models/transformers/dense_common.py @@ -0,0 +1,194 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> CausalLMOutputWithPast: + r""" + Copy paste LLaMa's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py + + This function should be generic enough for all pure text models. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def forward_with_triton_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index e8758e37382..220e83ef07a 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import sys -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import torch @@ -25,7 +24,6 @@ from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -230,65 +228,3 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - - -@dataclass -class CausalLMOutputWithoutLogits(CausalLMOutputWithPast): - last_hidden_state: Optional[torch.FloatTensor] = None - - -def forward_without_logits( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithoutLogits]: - r""" - Copy paste LLaMa's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py - - This function should be generic enough for all pure text models. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if labels is not None: - raise NotImplementedError("forward_without_logits does not support labels") - if not return_dict: - raise NotImplementedError("forward_without_logits has to return_dict") - - return CausalLMOutputWithoutLogits( - last_hidden_state=hidden_states, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index 353b4e54691..33594be6e48 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -106,11 +106,53 @@ def _ulysses_flash_attention_forward( return attn_output +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Choose the forward function based on the model and backend. + Args: + model (PreTrainedModel): The model to apply the monkey patch. + use_fused_kernels (bool): Whether to use fused kernels. + fused_kernels_backend (str): The backend to use for fused kernels. + """ + if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: + print(f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is {use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}") + return + + forward_with_torch_backend_function = model.__class__.forward + forward_with_triton_backend_function = model.__class__.forward + if model.config.model_type == "qwen2_5_vl": + from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type == "qwen2_vl": + from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + else: + from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + + if fused_kernels_backend == "triton": + model.__class__.forward = forward_with_triton_backend_function + print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") + elif fused_kernels_backend == "torch": + model.__class__.forward = forward_with_torch_backend_function + print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") + else: + raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + + def apply_monkey_patch( model: PreTrainedModel, ulysses_sp_size: int = 1, use_remove_padding: bool = True, use_fused_kernels: bool = False, + fused_kernels_backend: str = None, ): """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" module = sys.modules[model.__module__] @@ -124,7 +166,6 @@ def apply_monkey_patch( if model.config.model_type == "qwen2_5_vl": from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLFlashAttention2, - Qwen2_5_VLForConditionalGeneration, ) if use_remove_padding or ulysses_sp_size > 1: @@ -133,17 +174,9 @@ def apply_monkey_patch( Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2.5VL") - if use_fused_kernels: - from verl.models.transformers.qwen2_5_vl import forward_without_logits - - Qwen2_5_VLForConditionalGeneration.forward = forward_without_logits - - return - elif model.config.model_type == "qwen2_vl": from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLFlashAttention2, - Qwen2VLForConditionalGeneration, ) if use_remove_padding or ulysses_sp_size > 1: @@ -152,13 +185,6 @@ def apply_monkey_patch( Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2VL") - if use_fused_kernels: - from verl.models.transformers.qwen2_vl import forward_without_logits - - Qwen2VLForConditionalGeneration.forward = forward_without_logits - - return - # transformers<=4.47.1 if use_remove_padding or ulysses_sp_size > 1: if hasattr(module, "_flash_attention_forward"): @@ -171,10 +197,7 @@ def apply_monkey_patch( flash_attention._flash_attention_forward = _ulysses_flash_attention_forward print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") - if use_fused_kernels: - from verl.models.transformers.llama import forward_without_logits - - model.__class__.forward = forward_without_logits + patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) @lru_cache diff --git a/verl/models/transformers/qwen2_5_vl.py b/verl/models/transformers/qwen2_5_vl.py index 98fd367a01f..af3820ceaeb 100644 --- a/verl/models/transformers/qwen2_5_vl.py +++ b/verl/models/transformers/qwen2_5_vl.py @@ -23,18 +23,18 @@ @dataclass -class Qwen2_5_VLCausalLMOutputWithoutLogits(Qwen2_5_VLCausalLMOutputWithPast): - last_hidden_state: Optional[torch.FloatTensor] = None +class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None -def forward_without_logits( +def forward_base_model( self: Qwen2_5_VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -46,16 +46,13 @@ def forward_without_logits( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - **loss_kwargs, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithoutLogits]: +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" Copy paste Qwen2_5_VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: @@ -66,9 +63,7 @@ def forward_without_logits( n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) + raise ValueError(f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}") mask = input_ids == self.config.image_token_id mask_unsqueezed = mask.unsqueeze(-1) @@ -84,9 +79,7 @@ def forward_without_logits( n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) + raise ValueError(f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}") mask = input_ids == self.config.video_token_id mask_unsqueezed = mask.unsqueeze(-1) @@ -134,16 +127,152 @@ def forward_without_logits( return_dict=return_dict, cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) hidden_states = outputs[0] + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations if labels is not None: - raise NotImplementedError("forward_without_logits does not support labels") + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + +def forward_with_triton_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) + + hidden_states = outputs[0] + if not return_dict: - raise NotImplementedError("forward_without_logits has to return_dict") + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + reduction="none", + temperature=temperature, + ) - return Qwen2_5_VLCausalLMOutputWithoutLogits( - last_hidden_state=hidden_states, + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 91112ca6029..d79a774dcba 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -33,7 +33,7 @@ ) try: - from flash_attn import flash_attn_func, flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) except ImportError: @@ -293,18 +293,18 @@ def ulysses_flash_attn_forward( @dataclass -class Qwen2VLCausalLMOutputWithoutLogits(Qwen2VLCausalLMOutputWithPast): - last_hidden_state: Optional[torch.FloatTensor] = None +class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None -def forward_without_logits( +def forward_base_model( self: Qwen2VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -315,16 +315,13 @@ def forward_without_logits( video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - **loss_kwargs, -) -> Union[Tuple, Qwen2VLCausalLMOutputWithoutLogits]: +) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: r""" Copy paste Qwen2VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: @@ -335,15 +332,8 @@ def forward_without_logits( n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + raise ValueError(f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}") + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -353,15 +343,8 @@ def forward_without_logits( n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + raise ValueError(f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}") + video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -397,15 +380,150 @@ def forward_without_logits( cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + hidden_states = outputs[0] + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations if labels is not None: - raise NotImplementedError("forward_without_logits does not support labels") + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + +def forward_with_triton_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if not return_dict: - raise NotImplementedError("forward_without_logits has to return_dict") + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + reduction="none", + temperature=temperature, + ) - return Qwen2VLCausalLMOutputWithoutLogits( - last_hidden_state=hidden_states, + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/verl/protocol.py b/verl/protocol.py index 5b729134cec..64682a4d469 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -36,6 +36,7 @@ from verl.utils.py_functional import union_two_dict from verl.utils.torch_functional import allgather_dict_tensors +from verl.utils.device import get_torch_device __all__ = ["DataProto", "union_tensor_dict"] @@ -272,7 +273,7 @@ def __setstate__(self, data): batch_deserialized_bytes, non_tensor_batch, meta_info = data batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) - batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None) + batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None) self.batch = batch self.non_tensor_batch = non_tensor_batch self.meta_info = meta_info @@ -802,7 +803,7 @@ def all_gather_data_proto(data: DataProto, process_group): group_size = torch.distributed.get_world_size(group=process_group) assert isinstance(data, DataProto) prev_device = data.batch.device - data.batch = data.batch.cuda(device=torch.cuda.current_device()) + data.batch = data.batch.to(get_torch_device().current_device()) data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 5305fd6fb38..7be25fbe087 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -23,6 +23,7 @@ import ray from .decorator import Dispatch, Execute, register +from verl.utils.device import get_torch_device @dataclass @@ -147,10 +148,9 @@ def __init__(self, cuda_visible_devices=None) -> None: import torch from packaging import version - ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES") os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK") ### @@ -168,7 +168,7 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): self.local_rank = int(os.environ["LOCAL_RANK"]) cuda_visible_devices = str(local_rank) ### @@ -188,8 +188,8 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): - torch.cuda.set_device(int(cuda_visible_devices)) + if torch.cuda.is_available() and "AMD" in get_torch_device().get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + get_torch_device().set_device(int(cuda_visible_devices)) ### self.fused_worker_dict = {} diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index c0822e3cf27..ed086fc797e 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -95,13 +95,17 @@ def __init__( self.pgs = None self.detached = detached - def get_placement_groups(self, strategy="STRICT_PACK", name=None): + def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): if self.pgs is not None: return self.pgs pg_name_prefix = name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" # print(f"pg_name_prefix = {pg_name_prefix}") - pg_scheme = [[{"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] + if device_name == "npu": + device_name = "NPU" + elif device_name == "cuda": + device_name = "GPU" + pg_scheme = [[{"CPU": self.max_colocate_count, device_name: 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] lifetime = "detached" if self.detached else None @@ -174,7 +178,7 @@ def update_options(self, options: Dict): """ self._options.update(options) - def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None) -> Any: + def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None, device_name="cuda") -> Any: """Create and return a Ray actor with the configured options. Args: @@ -183,6 +187,7 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = use_gpu: Whether to use GPU resources num_gpus: Number of GPUs to allocate sharing_with: Actor to share resources with + device_name: Device for training Returns: A Ray actor handle with the configured options @@ -196,8 +201,10 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = options = {"scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx)} options.update(self._options) - if use_gpu: + if use_gpu and device_name == "cuda": options["num_gpus"] = num_gpus + if use_gpu and device_name == "npu": + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -227,6 +234,7 @@ def __init__( worker_names=None, worker_handles: List[ray.actor.ActorHandle] = None, ray_wait_register_center_timeout: int = 300, + device_name="cuda", **kwargs, ) -> None: """Initialize a RayWorkerGroup. @@ -249,6 +257,7 @@ def __init__( self.fused_worker_used = ray_cls_with_init.fused_worker_used # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to this WorkerGroup. self.sub_cls_name = "" + self.device_name = device_name if worker_names is not None and (not self.fused_worker_used): assert self._is_init_with_detached_workers @@ -300,7 +309,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d strategy = "PACK" if bin_pack: strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy) + pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size # cia.add_kwarg("_world_size", world_size) @@ -339,7 +348,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d ray_cls_with_init.update_options({"lifetime": "detached"}) # create a worker - worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus) + worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus, device_name=self.device_name) self._workers.append(worker) self._worker_names.append(name) diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 2b1a0941594..56bd824e71f 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -43,6 +43,7 @@ actor_rollout_ref: ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} use_torch_compile: True # False to disable torch compile # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified @@ -81,7 +82,7 @@ actor_rollout_ref: use_distributed_optimizer: True use_dist_checkpointing: False dist_checkpointing_path: null - seed: 1 + seed: 42 override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage profile: # profile the actor model in `update_policy` use_profile: False # open it when you want to profile the actor model @@ -107,7 +108,7 @@ actor_rollout_ref: use_distributed_optimizer: False use_dist_checkpointing: False dist_checkpointing_path: null - seed: 1 + seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} profile: use_profile: False @@ -118,6 +119,8 @@ actor_rollout_ref: load_weight: True log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} rollout: name: vllm mode: sync # sync: LLM, async: AsyncLLM @@ -139,6 +142,8 @@ actor_rollout_ref: max_num_seqs: 1024 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} disable_log_stats: True enable_chunked_prefill: False # could get higher throughput # for hf rollout @@ -205,13 +210,15 @@ critic: use_distributed_optimizer: True use_dist_checkpointing: False dist_checkpointing_path: null - seed: 1 + seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} load_weight: True ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} shuffle: ${actor_rollout_ref.actor.shuffle} @@ -219,6 +226,7 @@ critic: kl_ctrl: type: fixed kl_coef: 0.001 + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} checkpoint: contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space @@ -227,8 +235,6 @@ reward_model: strategy: megatron megatron: param_offload: False - grad_offload: False - optimizer_offload: False tensor_model_parallel_size: 1 expert_model_parallel_size: 1 expert_tensor_parallel_size: null @@ -239,7 +245,7 @@ reward_model: use_distributed_optimizer: False use_dist_checkpointing: False dist_checkpointing_path: null - seed: 1 + seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: {} model: input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical @@ -249,6 +255,7 @@ reward_model: micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu micro_batch_size_per_gpu: null use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} max_length: null reward_manager: naive launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob @@ -288,7 +295,7 @@ trainer: resume_from_path: null del_local_ckpt_after_load: False val_before_train: True - test_freq: 2 + test_freq: -1 critic_warmup: 0 default_hdfs_dir: null default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} @@ -296,6 +303,7 @@ trainer: max_critic_ckpt_to_keep: null # The timeout for ray worker group to wait for the register center to be ready ray_wait_register_center_timeout: 300 + device: cuda ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 6d3634ea09e..df1117deee9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -28,9 +28,12 @@ actor_rollout_ref: external_lib: null override_config: { } enable_gradient_checkpointing: True + enable_activation_offload: False use_remove_padding: False use_liger: False use_fused_kernels: False + fused_kernel_options: + impl_backend: torch # triton, torch trust_remote_code: False actor: strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility @@ -153,6 +156,7 @@ critic: override_config: { } external_lib: ${actor_rollout_ref.model.external_lib} enable_gradient_checkpointing: True + enable_activation_offload: False use_remove_padding: False trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} fsdp_config: @@ -177,6 +181,7 @@ critic: shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 cliprange_value: 0.5 + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} checkpoint: contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space @@ -188,6 +193,9 @@ reward_model: path: ~/models/FsfairX-LLaMA3-RM-v0.1 external_lib: ${actor_rollout_ref.model.external_lib} use_remove_padding: False + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fused_kernel_options: + impl_backend: ${actor_rollout_ref.model.fused_kernel_options.impl_backend} # triton, torch trust_remote_code: False fsdp_config: wrap_policy: @@ -249,6 +257,7 @@ trainer: max_critic_ckpt_to_keep: null # The timeout for ray worker group to wait for the register center to be ready ray_wait_register_center_timeout: 300 + device: cuda ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 633f2f4f9b4..018b5ca3662 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -30,7 +30,6 @@ import hydra import torch import torch.distributed -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from peft import LoraConfig, TaskType, get_peft_model from tensordict import TensorDict from torch import nn, optim @@ -55,8 +54,15 @@ get_ulysses_sequence_parallel_world_size, ulysses_pad_and_slice_inputs, ) +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis + logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) @@ -108,6 +114,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM # TODO: add checkpoint manager if self.device_mesh.get_rank() == 0: print(self.config) + self.device_name = get_device_name() def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) @@ -244,7 +251,7 @@ def _build_model_optimizer(self): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), cpu_offload=cpu_offload, use_orig_params=False, ) @@ -280,15 +287,15 @@ def _compute_loss_and_backward(self, batch, do_backward=True): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch["input_ids"].cuda() - attention_mask = batch["attention_mask"].cuda() - position_ids = batch["position_ids"].cuda() - loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).cuda() + input_ids = batch["input_ids"].to(self.device_name) + attention_mask = batch["attention_mask"].to(self.device_name) + position_ids = batch["position_ids"].to(self.device_name) + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) loss_fct = nn.CrossEntropyLoss(reduction="none") # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -398,15 +405,23 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage("After offload weights", logger=logger) - step_loss = torch.tensor(step_loss).cuda() - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} + step_loss = torch.tensor(step_loss).to(self.device_name) + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.ulysses_device_mesh.size(0) + return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval() with torch.no_grad(): loss = self._compute_loss_and_backward(batch, do_backward=False) - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.ulysses_device_mesh.size(0) return loss def save_checkpoint(self, step): @@ -461,7 +476,7 @@ def fit(self): desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", ): global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -471,7 +486,7 @@ def fit(self): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -487,7 +502,7 @@ def fit(self): # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -502,11 +517,12 @@ def fit(self): @hydra.main(config_path="config", config_name="sft_trainer", version_base=None) def main(config): + device_name = get_device_name() local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) # build tokenizer and datasets first from verl.utils import hf_tokenizer diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index 0d12b5b4b92..0f80b7caf91 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -38,6 +38,7 @@ from verl.utils.hdfs_io import makedirs from verl.utils.model import compute_position_id_with_mask from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.utils.device import is_cuda_available @hydra.main(config_path="config", config_name="generation", version_base=None) @@ -81,7 +82,7 @@ def main_task(config): ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu") wg.init_model() total_samples = len(dataset) diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 4e004fbff31..6cd00b83302 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -178,6 +178,7 @@ def run(self, config): val_dataset=val_dataset, collate_fn=collate_fn, train_sampler=train_sampler, + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit() diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index ba9abef52aa..532cb046799 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -75,12 +75,12 @@ def compute_gae_advantage_return( Args: token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) values: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) response_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma: `(float)` + shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma is `(float)` discounted factor used in RL lam: `(float)` lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) @@ -122,9 +122,9 @@ def compute_grpo_outcome_advantage( (with only one scalar reward for each response). Args: token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) response_mask: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) norm_adv_by_std_in_grpo: (bool) whether to scale the GRPO advantage. If True, the advantage is scaled by the std, as in the original GRPO. @@ -132,9 +132,9 @@ def compute_grpo_outcome_advantage( Returns: advantages: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) Returns: `(torch.Tensor)` - shape: (bs, response_length) + shape is (bs, response_length) """ scores = token_level_rewards.sum(dim=-1) @@ -371,15 +371,12 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str """ Aggregate the loss matrix into a scalar. Args: - loss_mat: `(torch.Tensor)` + loss_mat: `(torch.Tensor)`: shape: (bs, response_length) - loss_mask: `(torch.Tensor)` + loss_mask: `(torch.Tensor)`: shape: (bs, response_length) - loss_agg_mode: (str) choices: "token-mean" / - "seq-mean-token-sum" / - "seq-mean-token-mean" / - "seq-mean-token-sum-norm" / - "token-mean" is the default behavior + loss_agg_mode: (str) choices: + method to aggregate the loss matrix into a scalar. Returns: loss: `a scalar torch.Tensor` aggregated loss @@ -414,7 +411,7 @@ def compute_policy_loss( cliprange_low=None, cliprange_high=None, clip_ratio_c=3.0, - loss_agg_mode="token-mean", + loss_agg_mode: str = "token-mean", ): """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 Args: @@ -434,11 +431,7 @@ def compute_policy_loss( The higher clip range used in PPO. clip_ratio_c: (float) default: 3.0 The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729 - loss_agg_mode: (str) choices: "token-mean" / - "seq-mean-token-sum" / - "seq-mean-token-mean" / - "seq-mean-token-sum-norm" / - "token-mean" is the default behavior + loss_agg_mode: (str) see `agg_loss` Returns: pg_loss: `a scalar torch.Tensor` @@ -475,8 +468,8 @@ def compute_policy_loss( return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower -def compute_entropy_loss(logits, response_mask): - """Compute Categorical entropy loss +def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): + """Compute categorical entropy loss (For backward compatibility) Args: logits: `(torch.Tensor)` @@ -489,12 +482,12 @@ def compute_entropy_loss(logits, response_mask): """ # compute entropy - entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) - entropy_loss = verl_F.masked_mean(entropy, mask=response_mask) + token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) + entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) return entropy_loss -def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): +def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str = "token-mean"): """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 Args: @@ -504,6 +497,9 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): Old values of value head, shape (`batch_size`, `response_length`) returns: (`torch.FloatTensor`): Ground truth returns, shape (`batch_size`, `response_length`) + response_mask: `(torch.Tensor)` + Mask for tokens to calculate value function losses. # TODO: Rename to `state_mask`. + loss_agg_mode: (str) see `agg_loss` Returns: vf_loss: a scalar (`torch.FloatTensor`): @@ -515,7 +511,8 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) vf_losses1 = (vpreds - returns) ** 2 vf_losses2 = (vpredclipped - returns) ** 2 - vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask) + clipped_vf_losses = torch.max(vf_losses1, vf_losses2) + vf_loss = agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) return vf_loss, vf_clipfrac diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index ebb19d5829e..de727c2ce5a 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -124,7 +124,7 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} + node_available_gpus = {node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) for node, node_info in node_available_resources.items()} # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) @@ -146,6 +146,22 @@ def _check_resource_available(self): def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ responses = data.batch["responses"] response_length = responses.size(1) token_level_scores = data.batch["token_level_scores"] @@ -179,6 +195,17 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. + """ responses = data.batch["responses"] response_length = responses.size(1) attention_mask = data.batch["attention_mask"] @@ -186,6 +213,23 @@ def compute_response_mask(data: DataProto): def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True): + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ # Back-compatible with trainers that do not compute response mask in fit if "response_mask" not in data.batch: data.batch["response_mask"] = compute_response_mask(data) @@ -266,6 +310,18 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re @contextmanager def _timer(name: str, timing_raw: Dict[str, float]): + """Context manager for timing code execution. + + This utility function measures the execution time of code within its context + and accumulates the timing information in the provided dictionary. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + + Yields: + None: This is a context manager that yields control back to the code block. + """ with Timer(name=name, logger=None) as timer: yield if name not in timing_raw: @@ -294,8 +350,9 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, + device_name="cuda", ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' + """Initialize distributed PPO trainer with Ray backend.""" self.tokenizer = tokenizer self.processor = processor @@ -314,6 +371,7 @@ def __init__( self.use_reference_policy = Role.RefPolicy in role_worker_mapping self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name self.validation_generations_logger = ValidationGenerationsLogger() # define in-reward KL control @@ -679,7 +737,12 @@ def _validate(self): return metric_dict def init_workers(self): - """Init resource pool and worker group""" + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ self.resource_pool_manager.create_resource_pool() self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} @@ -727,7 +790,7 @@ def init_workers(self): for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) @@ -978,6 +1041,30 @@ def fit(self): old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) + if self.use_reference_policy: # compute reference log_prob with _timer("ref", timing_raw): diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index 23d4b1e70fa..7f6910ef35f 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -19,7 +19,7 @@ import ray from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score def get_custom_reward_fn(config): @@ -87,9 +87,9 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): if sandbox_url: sandbox_manager = multiprocessing.Manager() _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) - final_compute_score = partial(_default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) + final_compute_score = partial(default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) else: - final_compute_score = _default_compute_score + final_compute_score = default_compute_score return reward_manager_cls( tokenizer=tokenizer, diff --git a/verl/utils/activation_offload.py b/verl/utils/activation_offload.py new file mode 100644 index 00000000000..e07ee262609 --- /dev/null +++ b/verl/utils/activation_offload.py @@ -0,0 +1,551 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" + +from __future__ import annotations + +import functools +import logging +import os +from typing import Any, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl.utils.fsdp_utils import FSDPModule as FSDP2 + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +class FSDPParameterFilter: + def __init__(self): + self.model_parameters_storage = set() + + def __call__(self, tensor): + return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage + + def update_model_parameters(self, model): + new_storage = set() + for p in model.parameters(): + new_storage.add(p.data.untyped_storage().data_ptr()) + self.model_parameters_storage = new_storage + + +class CpuOffloadHookWithOffloadHandler: + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs + self.inside_context = False + + def __enter__(self): + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError("`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your custom tensor_push.") + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError("`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your custom tensor_pop.") + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with torch.cuda.stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f"{group_id} {state}" + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context(num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True)): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) + return args, kwargs + + def _ckpt_forward(self, forward_method, *args, **kwargs): + flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) + + def my_function(*inputs): + # unpack back into args and kwargs + nonlocal forward_method, kwarg_keys + unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) + # run original module + return forward_method(*unpacked_args, **unpacked_kwargs) + + return self.checkpoint_fn( + my_function, + *flat_args, + ) + + def forward(self, module, forward_method, *args, **kwargs): + if not module.training: + return forward_method(*args, **kwargs) + if not self._enable_ckpt: + ret = forward_method(*args, **kwargs) + else: + ret = self._ckpt_forward(forward_method, *args, **kwargs) + binded_tensor = ret + if isinstance(ret, tuple): + binded_tensor = ret[0] + binded_tensor = self._sync_func(binded_tensor) + final_ret = binded_tensor + if isinstance(ret, tuple): + final_ret = (final_ret,) + ret[1:] + return final_ret + + def wrap_module_forward_method(self, module): + orig_method = module.forward + handler = self + + @functools.wraps(orig_method) + def wrapped_method(model_self, *args, **kwargs): + nonlocal handler + handler.pre_forward(model_self) + out = handler.forward(model_self, orig_method, *args, **kwargs) + handler.post_forward(model_self) + return out + + module.forward = wrapped_method.__get__(module, type(module)) + + +def enable_activation_offloading(model, strategy, enable_ckpt=False): + """ + Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation + groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th + activation group happen at the same time, and there are at most two activation groups in GPU memory. + + Args: + model: the model to enable activation offloading + strategy: the training strategy of the model, such as "fsdp" + enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model + + Note: + For best efficiency, activation offloading is usually combined with activation checkpointing. However, this + implementation of activation offloading is conflicted with the implementation of activation checkpointing in + some training strategies. This function resolves this conflict, and therefore requires the "strategy" and + "enable_ckpt" arguments. + + Returns: + + """ + + assert strategy == "fsdp" or strategy == "fsdp2", "activation offloading only supports fsdp strategy" + layers = [] + + def get_layers(module): + for name, child in module.named_children(): + if not isinstance(child, (FSDP, FSDP2)): + get_layers(child) + else: + wrapped_module = child + if isinstance(child, FSDP): + wrapped_module = child._fsdp_wrapped_module + # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation + # size of torch.nn.Embedding is small, so it's not necessary to offload it. + if not isinstance(wrapped_module, torch.nn.Embedding): + layers.append(child) + + get_layers(model) + if len(layers) < 3: + logger.warning(f"Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading") + return + + tensor_filter = FSDPParameterFilter() + context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) + if enable_ckpt: + # The implementation of activation checkpointing in transformers library is incompatible with activation offloading, + # so it will be disabled, but this implementation supports another version of activation checkpointing, so that + # these two features can be enabled at the same time. + for module in model.modules(): + if hasattr(module, "gradient_checkpointing_disable"): + module.gradient_checkpointing_disable() + + handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt) + for layer in layers: + module = layer + if isinstance(layer, FSDP): + module = module._fsdp_wrapped_module + handler.wrap_module_forward_method(module) diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index 1914d475513..076a319bbca 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -23,6 +23,8 @@ from filelock import FileLock from transformers import PreTrainedTokenizer, ProcessorMixin +from verl.utils.device import is_cuda_available, is_npu_available + class BaseCheckpointManager: """ @@ -107,21 +109,42 @@ def local_mkdir(path): def get_rng_state(): rng_state = { "cpu": torch.get_rng_state(), - "cuda": torch.cuda.get_rng_state(), "numpy": np.random.get_state(), "random": random.getstate(), } + + if is_cuda_available: + rng_state["cuda"] = torch.cuda.get_rng_state() + elif is_npu_available: + rng_state["npu"] = torch.npu.get_rng_state() + return rng_state @staticmethod def load_rng_state(rng_state): torch.set_rng_state(rng_state["cpu"]) - torch.cuda.set_rng_state(rng_state["cuda"]) np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) + if is_cuda_available: + torch.cuda.set_rng_state(rng_state["cuda"]) + elif is_npu_available: + torch.npu.set_rng_state(rng_state["npu"]) + def find_latest_ckpt_path(path, directory_format="global_step_{}"): + """ + Return the most recent checkpoint directory based on a tracker file. + + Args: + path (str): Base directory containing the checkpoint tracker. + directory_format (str): Template for checkpoint subfolders with one + placeholder for the iteration number (default "global_step_{}"). + + Returns: + str or None: Full path to the latest checkpoint directory, or + None if the tracker or checkpoint folder is missing. + """ if path is None: return None diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index bf9206d5498..f5980129e91 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -23,6 +23,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin +from verl.utils.device import is_cuda_available from verl.utils.fs import copy_to_local, is_non_local from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx @@ -31,17 +32,20 @@ class FSDPCheckpointManager(BaseCheckpointManager): """ - A checkpoint manager that saves and loads - - model - - optimizer - - lr_scheduler - - extra_states - in a SPMD way. - - We save - - sharded model states and optimizer states - - full lr_scheduler states - - huggingface tokenizer/processor and config for ckpt merge + Manage FSDP checkpointing in SPMD training. + + - Saves/loads per-rank sharded model & optimizer states + - Persists full lr_scheduler and RNG state + - Stores HF tokenizer/processor and model/config for unified restore + + Args: + model (FSDP): Wrapped model instance. + optimizer (Optimizer): Training optimizer. + lr_scheduler (LRScheduler): Learning-rate scheduler. + processing_class (PreTrainedTokenizer or ProcessorMixin, optional): + Pre-/post-processing artifact handler. + checkpoint_contents (list[str], optional): + Components to include; must contain 'model', 'optimizer', 'extra'. """ def __init__( @@ -70,6 +74,18 @@ def __init__( ) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + """ + Load an FSDP checkpoint for this rank. + + Downloads and loads: + - model and optimizer shards + - extra state dict (scheduler + RNG) + + Args: + local_path: Directory with per-rank checkpoint files. + hdfs_path: Unused (for API compatibility). + del_local_after_load: Remove local files after loading. + """ if local_path is None: return @@ -96,8 +112,8 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: @@ -111,6 +127,23 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + """ + Save an FSDP checkpoint for this rank. + + Writes: + - model & optimizer shard files + - extra state dict (scheduler + RNG) + - HF tokenizer/processor and model/config on rank 0 + - optional full HF model under 'huggingface/' if requested + + Rotates old checkpoints, keeping at most `max_ckpt_to_keep`. + + Args: + local_path: Target directory for checkpoint files. + hdfs_path: Unused (for API compatibility). + global_step: Current training step (used for bookkeeping). + max_ckpt_to_keep: Number of recent checkpoints to retain. + """ if local_path is None: return @@ -127,8 +160,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i torch.distributed.barrier() # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with warnings.catch_warnings(): warnings.simplefilter("ignore") with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index e25c6cdd5c3..d24ed91abac 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -21,10 +21,12 @@ import torch.distributed from megatron.core import mpu, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject +from transformers import GenerationConfig from verl.models.weight_loader_registry import get_weight_saver from verl.utils.fs import is_non_local from verl.utils.megatron_utils import ( + get_hf_config_and_tokenizer_checkpoint_path, get_hf_model_checkpoint_path, get_model_checkpoint_path, get_optimizer_checkpoint_path, @@ -240,19 +242,28 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i print(f"Saving sharded model checkpoint to {local_path}") model_ckpt_path = get_model_checkpoint_path(local_path) - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + hf_config_and_tokenizer_path = get_hf_config_and_tokenizer_checkpoint_path(local_path) ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False) torch.save(state_dicts, os.path.join(ckpt_name)) + print(f"Saved checkpoint to {model_ckpt_path}") if self.rank == 0: - self.processing_class.save_pretrained(hf_model_ckpt_path) # tokenizer will be saved to hf_model_ckpt_path - print(f"Saved tokenizer to {hf_model_ckpt_path}") + self.processing_class.save_pretrained(hf_config_and_tokenizer_path) + self.hf_config.save_pretrained(hf_config_and_tokenizer_path) + if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: + try: + generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path) + generation_config.save_pretrained(hf_config_and_tokenizer_path) + except Exception: + # if the generation config isn't available, we don't save it + pass if hdfs_path is not None: print(f"Uploading checkpoint to {hdfs_path}") from verl.utils import hdfs_io hdfs_io.makedirs(hdfs_path, exist_ok=True) hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) + hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) if "hf_model" in self.checkpoint_contents: # wait for everyone to dump to local @@ -286,6 +297,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + self.processing_class.save_pretrained(hf_model_ckpt_path) + if hdfs_path is not None: print(f"Uploading checkpoint to {hdfs_path}") from verl.utils import hdfs_io diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index a7cd183945b..e952af5e057 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -35,7 +35,17 @@ def collate_fn(data_list: list[dict]) -> dict: - """Collate a batch of data.""" + """ + Collate a batch of sample dicts into batched tensors and arrays. + + Args: + data_list: List of dicts mapping feature names to torch.Tensor or other values. + + Returns: + Dict where tensor entries are stacked into a torch.Tensor of shape + (batch_size, *dims) and non-tensor entries are converted to + np.ndarray of dtype object with shape (batch_size,). + """ tensors = defaultdict(list) non_tensors = defaultdict(list) @@ -57,7 +67,19 @@ def collate_fn(data_list: list[dict]) -> dict: class RLHFDataset(Dataset): """ - We assume the dataset contains a column that contains prompts and other information + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. """ def __init__( @@ -247,10 +269,10 @@ def __getitem__(self, item): # encode prompts without chat template if self.return_raw_chat: row_dict["raw_prompt"] = messages - + # get prompts with chat template if self.return_full_prompt: - row_dict["full_prompts"] = raw_prompt # array of strings + row_dict["full_prompts"] = raw_prompt # array of strings # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0) diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index f461b3f47a8..dee5e7e6099 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -19,18 +19,19 @@ import torch.distributed as dist from verl.utils.logger.aggregate_logger import DecoratorLoggerBase +from verl.utils.device import get_torch_device def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]: """Get current memory usage.""" assert unit in ["GB", "MB", "KB"] divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 - mem_allocated = torch.cuda.memory_allocated() - mem_reserved = torch.cuda.memory_reserved() - # use torch.cuda.mem_get_info to profile device memory + mem_allocated = get_torch_device().memory_allocated() + mem_reserved = get_torch_device().memory_reserved() + # use get_torch_device().mem_get_info to profile device memory # since vllm's sleep mode works below pytorch # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 - mem_free, mem_total = torch.cuda.mem_get_info() + mem_free, mem_total = get_torch_device().mem_get_info() mem_used = mem_total - mem_free mem_allocated = f"{mem_allocated / divisor:.{precision}f}" mem_reserved = f"{mem_reserved / divisor:.{precision}f}" @@ -53,17 +54,12 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging class GPUMemoryLogger(DecoratorLoggerBase): """A decorator class to log GPU memory usage. - Usage: - For example, in actor function, we initialize a GPUMemoryLogger - - ``` - from verl.utils.debug.performance import GPUMemoryLogger - @GPUMemoryLogger(role="actor") - def update_actor(self, batch): - # do something - return - ``` - + Example: + >>> from verl.utils.debug.performance import GPUMemoryLogger + >>> @GPUMemoryLogger(role="actor") + >>> def update_actor(self, batch): + ... # real actor update logics + ... return """ def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): diff --git a/verl/utils/device.py b/verl/utils/device.py new file mode 100644 index 00000000000..ee9e279d212 --- /dev/null +++ b/verl/utils/device.py @@ -0,0 +1,57 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 7aa30c16815..101d972b6b0 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -14,6 +14,7 @@ """Utilities for distributed training.""" import os +from verl.utils.device import is_cuda_available, get_torch_device def initialize_global_process_group(timeout_second=36000): @@ -21,11 +22,11 @@ def initialize_global_process_group(timeout_second=36000): import torch.distributed - torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group("nccl" if is_cuda_available else "hccl", timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + get_torch_device().set_device(local_rank) return local_rank, rank, world_size diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 5eefbb7400e..c25bec89b2a 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig +from verl.utils.device import is_cuda_available, get_torch_device VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "deepseek_v3"} @@ -29,7 +30,7 @@ def unit_convert(number, level): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = get_torch_device().get_device_name() flops = float("inf") # INF flops for unkown gpu type if "MI300X" in device_name: diff --git a/verl/utils/fs.py b/verl/utils/fs.py index 9d280177cbc..a2a790968ac 100644 --- a/verl/utils/fs.py +++ b/verl/utils/fs.py @@ -32,10 +32,29 @@ def is_non_local(path): + """Check if a path is a non-local (HDFS) path. + + Args: + path (str): The path to check. + + Returns: + bool: True if the path is an HDFS path, False otherwise. + """ return path.startswith(_HDFS_PREFIX) def md5_encode(path: str) -> str: + """Generate an MD5 hash of a path string. + + This function is used to create unique identifiers for paths, typically + for creating cache directories or lock files. + + Args: + path (str): The path to encode. + + Returns: + str: The hexadecimal MD5 hash of the path. + """ return hashlib.md5(path.encode()).hexdigest() diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index c645cfe9787..53af8798736 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -29,6 +29,7 @@ from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name +from verl.utils.device import get_torch_device, get_device_name if version.parse(torch.__version__) >= version.parse("2.6"): from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard @@ -40,8 +41,8 @@ def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + get_torch_device().empty_cache() return x @@ -144,7 +145,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -152,7 +153,7 @@ def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): for param in model.parameters(): param.data = param.data.to(torch.device("cpu"), non_blocking=True) if empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -165,12 +166,12 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model loading to GPU" - device_id = torch.cuda.current_device() + device_id = get_torch_device().current_device() for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True) + handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ -279,7 +280,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = get_torch_device().current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -317,7 +318,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = get_torch_device().current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py new file mode 100644 index 00000000000..805759d4795 --- /dev/null +++ b/verl/utils/kernel/__init__.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .kernels import BackwardEnum, set_backward_method +from .linear_cross_entropy import linear_cross_entropy + +__all__ = ["linear_cross_entropy", "set_backward_method", "BackwardEnum"] diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py new file mode 100644 index 00000000000..c41b87088a8 --- /dev/null +++ b/verl/utils/kernel/kernels.py @@ -0,0 +1,1391 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementations of the linear cross entropy with token entropy kernel. +""" + +import typing +from dataclasses import dataclass + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + + +@dataclass +class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ + + _None = 0 + _Sum = 1 + _Mean = 2 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if reduction == "none": + _enum = EntropyReductionEnum._None + elif reduction == "sum": + _enum = EntropyReductionEnum._Sum + elif reduction == "mean": + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if ce_reduction == 0: + _enum = EntropyReductionEnum._None + elif ce_reduction == 1: + _enum = EntropyReductionEnum._Sum + elif ce_reduction == 2: + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + return _enum + + +@dataclass +class BackwardEnum: + """ + Enum for the backward method. + """ + + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size + _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. num_tokens + + +@dataclass +class Config: + _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N + _use_triton: bool = True + + +_config = Config() + + +def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. + """ + global _config + _config._backward = backward_method + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + rcp_temperature: tl.float32, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over K dimension + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + # load the next block of hidden and weight + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( + # (pid_n + 1) * vocab_per_split, vocab_size))), + # other=0.0) + + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), other=0.0) + + # GEMM + logits = tl.dot(_hidden, _weight.trans(), logits) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # scale logits by temperature + logits *= rcp_temperature + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store logprobs + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) + mask &= offs_am < num_tokens + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue( + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + num_tokens, + num_splits, + global_max_ptr, + stride_global_max: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + global_entropy_ptr, + stride_global_entropy: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + foward epilogue + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + + _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + _entropy_b = tl.load(entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # store entropy_b + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + # store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + global_logprobs = -1 * global_logprobs + if reduction == 0: + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif reduction == 2: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, + num_splits, + reduced_max_ptr, + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_max_ptr, + stride_global_max: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load(reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + _original_max = tl.load(original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + _accu = tl.load(accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load(entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update( + num_tokens, logprobs_ptr, stride_logprobs: tl.int64, maximum_ptr, stride_maximum: tl.int64, accumulate_ptr, stride_accumulate: tl.int64, entropy_b_ptr, stride_entropy_b: tl.int64, entropy_ptr, stride_entropy: tl.int64, logprobs_scalar_ptr, reduction: int, BLOCK_SIZE_M: tl.constexpr +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + if reduction == 0: + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + logprobs_scalar = tl.sum(logprobs, axis=0) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif reduction == 2: + logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + + +_dedicated_stream, _dedicated_events = None, None + + +def efficient_entropy_forward(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction: typing.Optional[int] = 2, temperature: typing.Optional[float] = 1.0, dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + """ + forward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + global _dedicated_stream, _dedicated_events + _dedicated_stream = torch.cuda.Stream(hidden.device) + _dedicated_events = [torch.cuda.Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): + logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) + else: + raise ValueError(f"Invalid reduction: {reduction}") + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + + if REDUCTION == EntropyReductionEnum._None: + _logprobs = logprobs + else: + _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + + if _config._use_triton: + # 1D kernel launch, then split the tile + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + + efficient_entropy_kernel_general_mainloop[mainloop_grid]( + _rank, + hidden, + weight, + labels, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + hidden.stride(0), + hidden.stride(1), + weight.stride(0), + weight.stride(1), + _max, + _max.stride(0), + _max.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + _logprobs, + _logprobs.stride(0), + logprobs, + 1.0 / temperature, + ) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + # reduction on maximum and maximum_indices + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( + _max, + _max.stride(0), + _max.stride(1), + num_tokens, + num_splits, + maximum, + maximum.stride(0), + _accu, + _accu.stride(0), + _accu.stride(1), + accumulate, + accumulate.stride(0), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + _logprobs, + _logprobs.stride(0), + logprobs, + REDUCTION, + ) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + torch.cuda.current_stream().record_event(_dedicated_events[0]) + with torch.cuda.stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( + num_tokens, + num_splits, + _max, + _max.stride(0), + _max.stride(1), + _max_backup, + _max_backup.stride(0), + _max_backup.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + ) + torch.cuda.current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid](num_tokens, _logprobs, _logprobs.stride(0), maximum, maximum.stride(0), accumulate, accumulate.stride(0), entropy_b, entropy_b.stride(0), entropy, entropy.stride(0), logprobs, REDUCTION) + + return (logprobs, entropy, maximum, accumulate, entropy_b) + + +# NOTE: merge d_weight & d_hidden here, split along M & N +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_mainloop_MN( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward mainloop, where d_logits & d_hidden & d_weight are fused + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n + d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # loop for d_weight & d_hidden + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) + # tl.atomic_add(d_weight_ptrs, + # _d_weight, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) + _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) + tl.atomic_add(d_weight_ptrs, _d_weight, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size)) + + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) + tl.atomic_add(d_hidden_ptrs, _d_hidden, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens)) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_hidden( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_hidden + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_k = pid // num_pid_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + # iterate over vocab_size + d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): + offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over hidden_size to get logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits + d_logits *= rcp_temperature + + # calculate d_hidden + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) + _weight = tl.load(weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0) + d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) + + # write back + tl.store(d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, d_hidden, mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_weight( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + pid_n = pid % num_pid_n + pid_k = pid // num_pid_n + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): + offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) + _hidden = tl.load(hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0) + d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) + + # write back + tl.store(d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, d_weight, mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size)) + + +# NOTE: split tile from d_logits' perspective +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store( + d_logits_ptrs, + d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits_split_N( + split_idx: int, + num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), other=0.0) + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + # filter d_logits with mask + result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) + + tl.store(d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + should_return_fp32_grad: bool = False, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> typing.List[torch.Tensor]: + """ + backward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + d_hidden, d_weight = None, None + if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + else: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + if _config._backward == BackwardEnum._Total_Fuse_MN: + # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), + 1.0 / temperature, + ) + + elif _config._backward == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + if _config._use_triton: + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + torch.matmul(_d_logits, weight, out=d_hidden) + torch.matmul(_d_logits.T, hidden, out=d_weight) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + elif _config._backward == BackwardEnum._Split_Dlogits_N: + vocab_per_split = 9504 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) + + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split + _d_logits = _d_logits[:, :vocab_right_bound].contiguous() + + if split_idx == 0: + torch.matmul(_d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden) + else: + d_hidden += torch.matmul(_d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]) + torch.matmul(_d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]) + + elif _config._backward == BackwardEnum._Split_Dlogits_M: + raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") + + return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py new file mode 100644 index 00000000000..8a7d43ec329 --- /dev/null +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,94 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import torch +import torch.distributed as dist + +from . import kernels + + +class LinearCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: typing.Optional[float] = 1.0, reduction: typing.Optional[str] = "none", dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + """_summary_ + + Args: + ctx (_type_): _description_ + hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size) + weight (torch.Tensor): (vocab_size, hidden_size) + labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, ) + temperature (typing.Optional[float], optional): _description_. Defaults to 1.0. + reduction (typing.Optional[str], optional): _description_. Defaults to "none". + dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None. + + Returns: + typing.List[torch.Tensor]: _description_ + """ + + assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}" + assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}" + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + original_hidden_shape = hidden.shape + if len(hidden.shape) != 2: + hidden = hidden.view(-1, hidden.shape[-1]) # (batch_size * num_tokens, hidden_size) + if len(labels.shape) != 1: + labels = labels.view(-1) + + logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(hidden, weight, labels, REDUCTION, temperature, dist_process_group) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors + REDUCTION = ctx.REDUCTION + dist_process_group = ctx.dist_process_group + should_return_fp32_grad = ctx.should_return_fp32_grad + temperature = ctx.temperature + + d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, _maximum, _accumulate, _entropy_b, REDUCTION, should_return_fp32_grad, temperature, dist_process_group) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return (d_hidden, d_weight, None, None, None, None) + + +linear_cross_entropy = LinearCrossEntropy.apply diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py index b7e272763ff..50ba6973625 100644 --- a/verl/utils/megatron/pipeline_parallel.py +++ b/verl/utils/megatron/pipeline_parallel.py @@ -47,6 +47,20 @@ def compute_transformers_input_shapes(batches, meta_info): def make_batch_generator(batches, vpp_size): + """ + Creates a batch generator suitable for Megatron pipeline parallelism, + handling virtual pipeline parallelism (VPP). + + If VPP is used (vpp_size > 1), it duplicates the batch iterator for each + virtual pipeline stage. Otherwise, it returns a single iterator. + + Args: + batches: An iterable (e.g., list) of micro-batches. + vpp_size (int): The virtual pipeline model parallel size. + + Returns: + An iterator or a list of iterators over the micro-batches. + """ if vpp_size > 1: # has vpp batch_generator = [batches] * vpp_size # number of vpp chunks diff --git a/verl/utils/megatron/sequence_parallel.py b/verl/utils/megatron/sequence_parallel.py index 9f4cbc08e87..52fda9b30cc 100644 --- a/verl/utils/megatron/sequence_parallel.py +++ b/verl/utils/megatron/sequence_parallel.py @@ -33,6 +33,7 @@ def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): unpad_tokens: (total_nnz, ...). Tokens after removing padding Returns: + the padded tokens: (total_nnz + pad_size,...) """ total_nnz = unpad_tokens.shape[0] diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 508ce6608d6..84d11a8f99a 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -1,5 +1,7 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +27,7 @@ from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.enums import ModelType -from megatron.core.optimizer import OptimizerConfig +from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_attr_wrapped_model @@ -296,12 +298,18 @@ def load_megatron_model_to_gpu(models, load_grad=True): @torch.no_grad() def offload_megatron_copy_params(optimizers): """ - Offload optimizer parameters to CPU + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. Args: - optimizers: The optimizer containing parameter groups to offload + optimizers: The optimizer or ChainedOptimizer instance. """ + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + def offload_tensor_to_cpu(tensor): if tensor is None: return @@ -321,21 +329,27 @@ def offload_group_to_cpu(group): else: offload_tensor_to_cpu(group) - # Offload all parameter groups to CPU + # Offload all parameter groups to CPU for each underlying optimizer - if hasattr(optimizers, "shard_fp32_from_float16_groups"): - offload_group_to_cpu(optimizers.shard_fp32_from_float16_groups) + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) @torch.no_grad() def load_megatron_copy_params(optimizers): """ - Load optimizer parameters back to GPU + Load optimizer parameters back to GPU. Handles ChainedOptimizer. Args: - optimizers: The optimizer containing parameter groups to load + optimizers: Optimizer or ChainedOptimizer instance. """ + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + def load_tensor_to_gpu(tensor): if tensor is None: return @@ -356,36 +370,49 @@ def load_group_to_gpu(group): else: load_tensor_to_gpu(group) - # Load all parameter groups to GPU + # Load all parameter groups to GPU for each underlying optimizer - if hasattr(optimizers, "shard_fp32_from_float16_groups"): - load_group_to_gpu(optimizers.shard_fp32_from_float16_groups) + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) @torch.no_grad() def offload_megatron_optimizer(optimizers): - offload_megatron_copy_params(optimizers) - opt_state_dict_values = optimizers.optimizer.state.values() - for v in opt_state_dict_values: - if 'exp_avg' in v: - v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) - if 'exp_avg_sq' in v: - v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) - gc.collect() - torch.cuda.empty_cache() + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + gc.collect() + torch.cuda.empty_cache() @torch.no_grad() def load_megatron_optimizer(optimizers): - load_megatron_copy_params(optimizers) - opt_state_dict_values = optimizers.optimizer.state.values() - for v in opt_state_dict_values: - if 'exp_avg' in v: - v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True) - if 'exp_avg_sq' in v: - v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True) - gc.collect() - torch.cuda.empty_cache() + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True) + gc.collect() + torch.cuda.empty_cache() def print_rank_0(message): @@ -407,6 +434,11 @@ def get_hf_model_checkpoint_path(checkpoint_path): return os.path.join(checkpoint_path, "huggingface") +def get_hf_config_and_tokenizer_checkpoint_path(checkpoint_path): + os.makedirs(checkpoint_path, exist_ok=True) + return os.path.join(checkpoint_path, "hf_config_and_tokenizer") + + def get_optimizer_checkpoint_path(checkpoint_path, use_distributed_optimizer=True): os.makedirs(os.path.join(checkpoint_path, "optim"), exist_ok=True) if not use_distributed_optimizer: @@ -650,7 +682,7 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m v_lst = [] assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0 + assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: @@ -664,10 +696,7 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m q = torch.cat(q_lst, dim=0) k = torch.cat(k_lst, dim=0) v = torch.cat(v_lst, dim=0) - if not convert_qkv_gate_up_by_simple_split: - infer_params = torch.cat((q, k, v), dim=0) - else: - infer_params = [q, k, v] + infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] elif layer_name_mapping.get("gate_proj_layer_name") in name: # if the tensor is gate and proj @@ -679,10 +708,10 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m up_lst.append(up) gate = torch.cat(gate_lst, dim=0) up = torch.cat(up_lst, dim=0) - if not convert_qkv_gate_up_by_simple_split: - infer_params = torch.cat((gate, up), dim=0) - else: - infer_params = [gate, up] + infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] + + elif "mlp.experts.linear_fc2.weight" in name: # moe + infer_params = torch.cat(infer_params, dim=1) else: # concat tensor @@ -691,11 +720,10 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m return infer_params -def per_tensor_generator(actor_module, model_config, weight_converter, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True): +def per_tensor_generator(actor_module, model_config, weight_converter, transformer_config, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True): from megatron.core import parallel_state as mpu pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() ep_size = mpu.get_expert_model_parallel_world_size() etp_size = mpu.get_expert_tensor_parallel_world_size() ep_group = mpu.get_expert_model_parallel_group() @@ -722,12 +750,18 @@ def tensor_generator(): # lazy load tensor for full model for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: + if model_config.tie_word_embeddings and ("output_layers" in name): + import warnings + + warnings.warn("Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2) + continue + if cur_pp_rank == pp_rank: try: cur_name, cur_tensor = next(gen_func) except StopIteration: cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, model_config.num_hidden_layers) + cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config) else: cur_tensor, cur_name = None, None diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 23c12a93b6d..69c45f2f032 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -157,6 +157,18 @@ def union_two_dict(dict1: Dict, dict2: Dict): def append_to_dict(data: Dict, new_data: Dict): + """Append values from new_data to lists in data. + + For each key in new_data, this function appends the corresponding value to a list + stored under the same key in data. If the key doesn't exist in data, a new list is created. + + Args: + data (Dict): The target dictionary containing lists as values. + new_data (Dict): The source dictionary with values to append. + + Returns: + None: The function modifies data in-place. + """ for key, val in new_data.items(): if key not in data: data[key] = [] @@ -164,6 +176,21 @@ def append_to_dict(data: Dict, new_data: Dict): class NestedNamespace(SimpleNamespace): + """A nested version of SimpleNamespace that recursively converts dictionaries to namespaces. + + This class allows for dot notation access to nested dictionary structures by recursively + converting dictionaries to NestedNamespace objects. + + Example: + config_dict = {"a": 1, "b": {"c": 2, "d": 3}} + config = NestedNamespace(config_dict) + # Access with: config.a, config.b.c, config.b.d + + Args: + dictionary: The dictionary to convert to a nested namespace. + **kwargs: Additional attributes to set on the namespace. + """ + def __init__(self, dictionary, **kwargs): super().__init__(**kwargs) for key, value in dictionary.items(): diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index 49b60ef45c6..5d08b83ec66 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -16,11 +16,26 @@ """ import concurrent.futures +from typing import Any, List, Optional import ray -def parallel_put(data_list, max_workers=None): +def parallel_put(data_list: List[Any], max_workers: Optional[int] = None): + """ + Puts a list of data into the Ray object store in parallel using a thread pool. + + Args: + data_list (List[Any]): A list of Python objects to be put into the Ray object store. + max_workers (int, optional): The maximum number of worker threads to use. + Defaults to min(len(data_list), 16). + + Returns: + List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list, + maintaining the original order. + """ + assert len(data_list) > 0, "data_list must not be empty" + def put_data(index, data): return index, ray.put(data) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 55b8621f3ff..1466e498d88 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -13,8 +13,25 @@ # limitations under the License. # from . import gsm8k, math, prime_math, prime_code +from verl.utils.import_utils import deprecated -def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): + +def default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): + """Compute the score for a given solution based on the data source. + + Args: + data_source (str): The source dataset identifier which determines the scoring method. + solution_str (str): The solution string to be evaluated. + ground_truth (str): The ground truth answer for comparison. + extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None. + + Returns: + float: The computed score as a floating point number. If the result is a dictionary, + it returns the dictionary instead. + + Raises: + NotImplementedError: If the reward function is not implemented for the given data source. + """ if data_source == "openai/gsm8k": from . import gsm8k @@ -71,3 +88,14 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N return float(res) else: return float(res[0]) + + +@deprecated("verl.utils.reward_score.default_compute_score") +def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): + """ + Legacy function API to be deprecated. Please use `default_compute_score` instead. + """ + return default_compute_score(data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore) + + +__all__ = ["default_compute_score"] diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 82dc837645d..e2e567050da 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -141,20 +141,30 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool): - """get order of seq lengths to make partitions balanced, this is - used in balacing sum of seqlength across dp ranks and microbatches - Parameters: - seqlen_list (List[int]): - seq lengths of each items - k_partitions (int): - resulting number of partitions - equal_size (bool): - if True, number of items in each partitions must be equal. - if False, only consider balancing the sum, each partition can have - variable number of items + """ + Calculates partitions of indices from seqlen_list such that the sum of sequence lengths + in each partition is balanced. Uses the Karmarkar-Karp differencing method. + + This is useful for balancing workload across devices or batches, especially when + dealing with variable sequence lengths. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + k_partitions (int): The desired number of partitions. + equal_size (bool): If True, ensures that each partition has the same number of items. + Requires len(seqlen_list) to be divisible by k_partitions. + If False, partitions can have varying numbers of items, focusing + only on balancing the sum of sequence lengths. + Returns: - partitions (List[List[int]]): - return k_partitions list containing the index of items. + List[List[int]]: A list containing k_partitions lists. Each inner list contains the + original indices of the items assigned to that partition. The indices + within each partition list are sorted. + + Raises: + AssertionError: If len(seqlen_list) < k_partitions. + AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. + AssertionError: If any resulting partition is empty. """ assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" @@ -212,7 +222,11 @@ def ceildiv(a, b): return -(a // -b) -def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_in_dp=True, min_num_micro_batch=None): +def roundup_divisible(a, b): + return ((a + b - 1) // b) * b + + +def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_divided_by=None, same_micro_num_in_dp=True, min_num_micro_batch=None): """ Split a batch into micro-batches by total token count, with optional DP sync and padding. @@ -220,6 +234,7 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_ batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly. max_token_len (int): max sum of attention_mask per micro-batch. dp_group (optional): torch.distributed group for data-parallel sync. + num_batches_divided_by (optional): virtual pipeline parallel size, for megatron. same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count. min_num_micro_batch (int, optional): force at least this many splits (pads empty ones). @@ -241,6 +256,8 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_ num_micro_batches = torch.tensor([num_micro_batches], device="cuda") dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item() + if num_batches_divided_by is not None: + num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by) seq_len_effective = seq_len_effective.tolist() assert num_micro_batches <= len(seq_len_effective) @@ -261,6 +278,15 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_ def get_reverse_idx(idx_map): + """ + Build the inverse of an index mapping. + + Args: + idx_map (Sequence[int]): Sequence where idx_map[i] = j. + + Returns: + List[int]: Inverse mapping list such that output[j] = i for each i. + """ reverse_idx_map = copy.deepcopy(idx_map) for i, idx in enumerate(idx_map): diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 9754d989344..e728758d49b 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -53,7 +53,20 @@ def gather_from_labels(data, label): def logprobs_from_logits(logits, labels, inplace_backward=True): """ + Compute per-token log-probabilities for the given labels. + + Uses a Flash-Attention–based cross-entropy (if available) for efficient backward, + otherwise falls back to a standard log-softmax+gather approach. + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + + Args: + logits (Tensor): Model outputs of shape (..., vocab_size). + labels (LongTensor): True class indices of shape matching logits[..., :-1]. + inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place. + + Returns: + Tensor: Log-probabilities of the target labels, shape logits.shape[:-1]. """ if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: batch_dim = logits.shape[:-1] @@ -121,7 +134,18 @@ def masked_sum(values, mask, axis=None): def masked_mean(values, mask, axis=None): - """Compute mean of tensor with a masked values.""" + """ + Compute the mean of `values` over elements selected by `mask`. + + Args: + values (Tensor): Input tensor. + mask (Tensor): Boolean or numeric mask of the same shape as `values`. + axis (int or tuple of int, optional): Dimension(s) along which to compute the mean. + Defaults to None (over all elements). + + Returns: + Tensor: Masked mean, with shape equal to `values` reduced over `axis`. + """ return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8) @@ -144,7 +168,18 @@ def masked_var(values, mask, unbiased=True): def masked_whiten(values, mask, shift_mean=True): - """Whiten values with masked values.""" + """ + Whiten `values` by normalizing with mean and variance computed over `mask`. + + Args: + values (torch.Tensor): Input tensor. + mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. + shift_mean (bool): If True (default), output is zero-mean; + if False, the original mean is re-added after scaling. + + Returns: + torch.Tensor: Whitened tensor of same shape as `values`. + """ mean, var = masked_mean(values, mask), masked_var(values, mask) whitened = (values - mean) * torch.rsqrt(var + 1e-8) if not shift_mean: @@ -472,6 +507,18 @@ def get_constant_schedule_with_warmup( num_warmup_steps: int, last_epoch: int = -1, ): + """ + Create a constant LR schedule with a linear warmup phase. + + Args: + optimizer (Optimizer): Wrapped optimizer. + num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value. + last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1. + + Returns: + LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant. + """ + def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index fbd62bc6199..b8573ec771a 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -23,6 +23,16 @@ class Tracking: + """A unified tracking interface for logging experiment data to multiple backends. + + This class provides a centralized way to log experiment metrics, parameters, and artifacts + to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. + + Attributes: + supported_backend: List of supported tracking backends. + logger: Dictionary of initialized logger instances for each backend. + """ + supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console", "clearml"] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None): @@ -70,9 +80,9 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") if SWANLAB_API_KEY: swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten - + if config is None: - config = {} # make sure config is not None, otherwise **config will raise error + config = {} # make sure config is not None, otherwise **config will raise error swanlab.init( project=project_name, experiment_name=experiment_name, @@ -147,7 +157,7 @@ def __init__(self, project_name: str, experiment_name: str, config): output_uri=False, ) - self._task.connect_configuration(config, name='Hyperparameters') + self._task.connect_configuration(config, name="Hyperparameters") def _get_logger(self): return self._task.get_logger() @@ -159,7 +169,7 @@ def log(self, data, step): # logs = self._rewrite_logs(data) logger = self._get_logger() for k, v in data.items(): - title, series = k.split('/', 1) + title, series = k.split("/", 1) if isinstance(v, (int, float, np.floating, np.integer)): logger.report_scalar( @@ -176,12 +186,7 @@ def log(self, data, step): iteration=step, ) else: - logger.warning( - 'Trainer is attempting to log a value of ' - f'"{v}" of type {type(v)} for key "{k}". ' - "This invocation of ClearML logger's function " - 'is incorrect so this attribute was dropped. ' - ) + logger.warning(f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This invocation of ClearML logger\'s function is incorrect so this attribute was dropped. ') def finish(self): self._task.mark_completed() @@ -334,7 +339,7 @@ def log_generations_to_mlflow(self, samples, step): print(f"WARNING: save validation generation file to mlflow failed with error {e}") def log_generation_to_clearml(self, samples, step): - """ Log validation generation to clearml as table""" + """Log validation generation to clearml as table""" import clearml import pandas as pd diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py index bf587081d32..a33293364f1 100644 --- a/verl/utils/ulysses.py +++ b/verl/utils/ulysses.py @@ -242,6 +242,21 @@ def gather_outpus_and_unpad( grad_scaler: bool = True, group: Optional[dist.ProcessGroup] = None, ): + """ + Gather a tensor across a process group and optionally unpad its padded elements. + + Args: + x (Tensor): Input tensor to gather. + gather_dim (int): Dimension along which to gather across ranks. + unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding. + padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0. + grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. + group (ProcessGroup, optional): Process group for gathering. If None, uses + `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged. + + Returns: + Tensor: The gathered tensor, with padding removed if requested. + """ group = get_ulysses_sequence_parallel_group() if group is None else group if group is None: return x diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 7ae5176bbc8..0f1c7c562bf 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -23,7 +23,6 @@ from typing import Tuple import torch -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -31,6 +30,7 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches @@ -38,6 +38,12 @@ from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor +if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +elif is_npu_available: + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + + __all__ = ["DataParallelPPOActor"] logger = logging.getLogger(__file__) @@ -64,15 +70,7 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim if self.config.get("use_torch_compile", True) # use torch compile by default else verl_F.entropy_from_logits ) - - if self.use_fused_kernels: - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - self.fused_linear_for_ppo = FusedLinearForPPO() - - # FusedLinearForPPO has an error when compiled, disable for now - # if self.config.get("use_torch_compile", True): - # self.fused_linear_for_ppo.compile(dynamic=True) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -86,7 +84,7 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -124,29 +122,25 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + output = self.actor_module( input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, **multi_modal_inputs, use_cache=False, + **extra_args, ) # prevent model thinks we are generating if self.use_fused_kernels: - hidden_states = output.last_hidden_state - vocab_weights = self.actor_module.lm_head.weight - - log_probs, entropy_rmpad = self.fused_linear_for_ppo( - hidden_states=hidden_states.squeeze(0), - vocab_weights=vocab_weights, - input_ids=input_ids_rmpad_rolled, - temperature=temperature, - ) + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) else: logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - - # logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) logits_rmpad.div_(temperature) # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) @@ -200,24 +194,21 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature output = self.actor_module( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, **multi_modal_inputs, use_cache=False, + **extra_args, ) # prevent model thinks we are generating if self.use_fused_kernels: - hidden_states = output.last_hidden_state - vocab_weights = self.actor_module.lm_head.weight - - log_probs, entropy = self.fused_linear_for_ppo( - hidden_states=hidden_states[:, -response_length - 1 : -1, :], - vocab_weights=vocab_weights, - input_ids=micro_batch["responses"], - temperature=temperature, - ) + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) else: logits = output.logits @@ -359,9 +350,9 @@ def update_policy(self, data: DataProto): for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload + data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) attention_mask = data["attention_mask"] @@ -410,7 +401,7 @@ def update_policy(self, data: DataProto): ref_log_prob = data["ref_log_prob"] # compute kl loss kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics["actor/kl_loss"] = kl_loss.detach().item() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 8a490fac8d4..cdabdea85d3 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -20,6 +20,7 @@ """ import copy +import itertools import logging import os from functools import partial @@ -44,7 +45,8 @@ from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits from verl.utils.megatron_utils import get_model_config from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import broadcast_dict_tensor, split_dict_tensor_into_batches +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import broadcast_dict_tensor from verl.workers.actor import BasePPOActor __all__ = ["MegatronPPOActor"] @@ -165,8 +167,15 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te DataProto: torch.Tensor: the log_prob tensor """ data.batch = data.batch.contiguous() - - def compute_logprobs_fn(output, data): + use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) + micro_batch_size = data.meta_info.get("micro_batch_size", None) + max_token_len = data.meta_info.get("max_token_len", None) + assert micro_batch_size is not None, "micro batch size is needed for forward compute" + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + max_token_len = max_token_len * self.config.megatron.context_parallel_size + + def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): response = data["responses"] response_length = response.size(1) log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous() @@ -185,14 +194,20 @@ def compute_logprobs_fn(output, data): response = batch["responses"] response_length = response.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn, calculate_entropy=calculate_entropy) + output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn, calculate_entropy=calculate_entropy, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank if calculate_entropy: - log_probs = torch.cat([o[0]["log_probs"] for o in output], dim=0) # (bs, seq_size) + log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size) else: - log_probs = torch.cat([o["log_probs"] for o in output], dim=0) # (bs, seq_size) - log_probs = log_probs.to(torch.float32) + log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size) + log_probs = torch.cat(log_probs, dim=0).to(torch.float32) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + log_probs = log_probs[revert_indices] else: log_probs = torch.empty(size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device) @@ -206,8 +221,14 @@ def compute_logprobs_fn(output, data): if calculate_entropy: # Note that o[0] is metrics, o[1] is entropy if mpu.is_pipeline_last_stage(ignore_virtual=True): - entropys = torch.cat([o[1] for o in output], dim=0) + entropys = torch.cat([o[1] for o in output["output"]], dim=0) entropys = entropys.to(torch.float32) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + entropys = entropys[revert_indices] else: entropys = torch.empty(size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device) # broadcast across pp ranks @@ -256,7 +277,7 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: dataloader_kwargs={"shuffle": self.config.shuffle}, ) - def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None, calculate_entropy=False): + def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None, calculate_entropy=False, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None, mini_batch_size=None): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -264,18 +285,29 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce """ # broadcast from last pp rank to all other pp ranks # TODO: actually, we just need to control the sampling order. - broadcast_dict_tensor(data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + mini_batch = data + broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) # split into micro-batches - data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) - - if data.meta_info.get("micro_batch_size", None) is not None: - batch_size = data.meta_info["micro_batch_size"] + mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + + indices = None + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size is not None and vpp_size > 1: + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + else: + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + total_seqlen = max_token_len else: - batch_size = self.config.ppo_micro_batch_size_per_gpu - batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size) + assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + micro_batches = mini_batch.batch.split(micro_batch_size) + seq_len = micro_batches[0]["input_ids"].shape[1] + total_seqlen = micro_batch_size * seq_len # compute input shapes for pp stages - n_micro_batch = len(batches) - seq_len = batches[0]["input_ids"].shape[1] + n_micro_batch = len(micro_batches) forward_backward_func = get_forward_backward_func() @@ -355,6 +387,7 @@ def loss_func(output, data, meta_info): "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), } ) + append_to_dict(metrics, stats) return policy_loss, [metrics, ret_entropy] @@ -375,12 +408,16 @@ def forward_step(batch_iter, model): def logits_processor(logits, label, label_mask): assert logits.shape[:2] == label.shape[:2] assert label.shape == label_mask.shape - log_probs = vocab_parallel_log_probs_from_logits(logits, label) - log_probs = log_probs.masked_fill(~label_mask, 0.0) - ret = {"log_probs": log_probs} + + ret = {} + if calculate_entropy: entropy = vocab_parallel_entropy(logits) ret["entropy"] = entropy + + log_probs = vocab_parallel_log_probs_from_logits(logits, label) + log_probs = log_probs.masked_fill(~label_mask, 0.0) + ret["log_probs"] = log_probs return ret logits_processor_args = {"label": label, "label_mask": label_mask} @@ -403,7 +440,7 @@ def logits_processor(logits, label, label_mask): return output, partial(loss_func, data=batch, meta_info=meta_info) # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module)) + batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module)) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) @@ -413,7 +450,7 @@ def logits_processor(logits, label, label_mask): data_iterator=batch_generator, model=self.actor_module, num_microbatches=n_micro_batch, - seq_length=batch_size * seq_len, # no use when input_shapes was set + seq_length=total_seqlen, # no use when input_shapes was set micro_batch_size=1, # no use when input_shapes was set forward_only=forward_only, ) @@ -423,11 +460,14 @@ def logits_processor(logits, label, label_mask): data_iterator=batch_generator, model=self.actor_module, num_microbatches=n_micro_batch, - seq_length=batch_size * seq_len, # in use for pp = 1 + seq_length=total_seqlen, # in use for pp = 1 micro_batch_size=1, # in use for pp = 1 forward_only=forward_only, ) # loss_reduces contains the stats returned from loss_func + losses_reduced = {"output": losses_reduced} + if use_dynamic_bsz: + losses_reduced["indices"] = indices return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) @@ -454,7 +494,15 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: chunk.zero_grad_buffer() calculate_entropy = self.config.entropy_coeff != 0 - metric_micro_batch = self.forward_backward_batch(data, calculate_entropy=calculate_entropy) + if data.meta_info.get("micro_batch_size", None) is not None: + micro_batch_size = data.meta_info["micro_batch_size"] + else: + micro_batch_size = self.config.ppo_micro_batch_size_per_gpu + max_token_len = None + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + metric_micro_batch = self.forward_backward_batch(data, calculate_entropy=calculate_entropy, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size) + metric_micro_batch = metric_micro_batch["output"] for metric in metric_micro_batch: # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index dc83028a620..7d1c85a72a2 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -34,8 +34,13 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic +from verl.utils.device import get_device_name, get_torch_device, is_npu_available, is_cuda_available -__all__ = ["DataParallelPPOCritic"] + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -50,6 +55,7 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt print(f"Critic use_remove_padding={self.use_remove_padding}") self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch["responses"].size(-1) @@ -58,7 +64,7 @@ def _forward_micro_batch(self, micro_batch): for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -207,9 +213,9 @@ def update_critic(self, data: DataProto): for data in micro_batches: # Support all devices if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload + data = data.to(get_torch_device().current_device()) # critic device is cpu when using offload responses = data["responses"] attention_mask = data["attention_mask"] values = data["values"] @@ -228,6 +234,7 @@ def update_critic(self, data: DataProto): returns=returns, response_mask=response_mask, cliprange_value=self.config.cliprange_value, + loss_agg_mode=self.config.loss_agg_mode, ) if self.config.use_dynamic_bsz: # relative to the dynamic bsz diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 68c1d51e889..2d5d4fc71e3 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -15,6 +15,7 @@ Implement a multiprocess PPOCritic """ +import itertools import logging import os from functools import partial @@ -33,7 +34,8 @@ from verl.utils.debug import GPUMemoryLogger from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean, split_dict_tensor_into_batches +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean from verl.workers.critic import BasePPOCritic logger = logging.getLogger(__file__) @@ -91,13 +93,26 @@ def compute_values(self, data: DataProto) -> DataProto: # data.batch = data.batch.to(self.critic_module.module.device) responses = data.batch["responses"] attention_mask = data.batch["attention_mask"] + use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) + micro_batch_size = data.meta_info.get("micro_batch_size", None) + max_token_len = data.meta_info.get("max_token_len", None) + assert micro_batch_size is not None, "micro batch size is needed for forward compute" + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + max_token_len = max_token_len * self.config.megatron.context_parallel_size response_length = responses.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data=data, forward_only=True) + output = self.forward_backward_batch(data=data, forward_only=True, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=None) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank - values = torch.cat([o["vpreds"] for o in output], dim=0) # (bs, seq_size, vocal_size) - values = values.to(torch.float32) + values = [o["vpreds"] for o in output["output"]] # (bs, seq_size, vocal_size) + values = torch.cat(values, dim=0).to(torch.float32) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + values = values[revert_indices] else: values = torch.empty_like(attention_mask, dtype=torch.float32) @@ -128,19 +143,37 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: dataloader_kwargs={"shuffle": self.config.shuffle}, ) - def forward_backward_batch(self, data: DataProto, forward_only=False): + def forward_backward_batch(self, data: DataProto, forward_only=False, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None, mini_batch_size=None): # broadcast from last pp rank to all other pp ranks - data.batch = data.batch.contiguous() - broadcast_dict_tensor(data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + mini_batch = data + mini_batch.batch = mini_batch.batch.contiguous() + broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) # split into micro-batches - data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) - batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size_per_gpu) - n_micro_batch = len(batches) - seq_len = batches[0]["input_ids"].shape[1] + mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + + indices = None + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size is not None and vpp_size > 1: + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + else: + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + total_seqlen = max_token_len + else: + assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + micro_batches = mini_batch.batch.split(micro_batch_size) + seq_len = micro_batches[0]["input_ids"].shape[1] + total_seqlen = micro_batch_size * seq_len + n_micro_batch = len(micro_batches) forward_backward_func = get_forward_backward_func() def loss_func(output, data, meta_info): + nonlocal use_dynamic_bsz + if forward_only: return torch.tensor(1.0, device=output.device), {"vpreds": output} @@ -163,7 +196,9 @@ def loss_func(output, data, meta_info): returns=returns, response_mask=response_mask, cliprange_value=cliprange_value, + loss_agg_mode=self.config.loss_agg_mode, ) + stats = { "critic/vf_loss": vf_loss.detach().item(), "critic/vf_clipfrac": vf_clipfrac.detach().item(), @@ -193,7 +228,7 @@ def forward_step(batch_iter, model): return output, partial(loss_func, data=batch, meta_info={}) # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(batches, vpp_size=len(self.critic_module)) + batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module)) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) @@ -203,7 +238,7 @@ def forward_step(batch_iter, model): data_iterator=batch_generator, model=self.critic_module, num_microbatches=n_micro_batch, - seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len, # no use when input_shapes was set + seq_length=total_seqlen, # no use when input_shapes was set micro_batch_size=1, # no use when input_shapes was set forward_only=forward_only, ) @@ -213,11 +248,14 @@ def forward_step(batch_iter, model): data_iterator=batch_generator, model=self.critic_module, num_microbatches=n_micro_batch, - seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len, # in use for pp = 1 + seq_length=total_seqlen, # in use for pp = 1 micro_batch_size=1, # in use for pp = 1 forward_only=forward_only, ) # loss_reduces contains the stats returned from loss_func + losses_reduced = {"output": losses_reduced} + if use_dynamic_bsz: + losses_reduced["indices"] = indices return losses_reduced @GPUMemoryLogger("megatron critic", logger=logger) @@ -231,9 +269,16 @@ def update_critic(self, dataloader: Iterable[DataProto]): for chunk in self.critic_module: chunk.zero_grad_buffer() - metric_micro_batch = self.forward_backward_batch(data) - + micro_batch_size = self.config.ppo_micro_batch_size_per_gpu + max_token_len = None + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + metric_micro_batch = self.forward_backward_batch(data, forward_only=False, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size) + metric_micro_batch = metric_micro_batch["output"] update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step() + learning_rate = self.critic_optimizer.param_groups[-1]["lr"] + data = {"critic/grad_norm": grad_norm, "critic/lr": learning_rate} + append_to_dict(metrics, data) if update_successful: # allgather already execute in optimizer.step in new megatron diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 48114746b67..99a671e14ff 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -33,8 +33,10 @@ from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( @@ -58,12 +60,14 @@ logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +device_name = get_device_name() + def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]) return device_mesh @@ -93,7 +97,7 @@ def __init__(self, config: DictConfig, role: str): if not torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) - torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", rank=rank, world_size=world_size) + torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl", rank=rank, world_size=world_size) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -105,7 +109,7 @@ def __init__(self, config: DictConfig, role: str): self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -160,6 +164,7 @@ def _build_model_optimizer( trust_remote_code=False, use_liger=False, role="actor", + enable_activation_offload=False, ): from torch import optim from torch.distributed.fsdp import CPUOffload, MixedPrecision @@ -224,11 +229,15 @@ def _build_model_optimizer( _apply_liger_kernel_to_instance(model=actor_module) + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + apply_monkey_patch( model=actor_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 @@ -279,7 +288,7 @@ def _build_model_optimizer( param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -309,6 +318,9 @@ def _build_model_optimizer( else: raise NotImplementedError(f"not implement {fsdp_strategy}") + if enable_activation_offload: + enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) # TODO: add more optimizer args into config @@ -354,7 +366,7 @@ def _build_rollout(self, trust_remote_code=False): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) rollout_name = self.config.rollout.name if rollout_name == "hf": from verl.workers.rollout import HFRollout @@ -511,6 +523,7 @@ def init_model(self): trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), ) # get the original unwrapped module @@ -566,13 +579,13 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) @@ -583,8 +596,8 @@ def update_actor(self, data: DataProto): global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) lr = self.actor_lr_scheduler.get_last_lr()[0] @@ -609,7 +622,7 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) + prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout @@ -639,7 +652,7 @@ def generate_sequences(self, prompts: DataProto): output = output.to("cpu") # clear kv cache - torch.cuda.empty_cache() + get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -649,7 +662,7 @@ def compute_log_prob(self, data: DataProto): load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -683,7 +696,7 @@ def compute_ref_log_prob(self, data: DataProto): assert self._is_ref # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -740,7 +753,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -754,7 +767,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -871,7 +884,7 @@ def _build_critic_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -900,6 +913,10 @@ def _build_critic_model_optimizer(self, config): else: raise NotImplementedError(f"Unknown strategy {config.strategy}") + if config.model.get("enable_activation_offload", False): + enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) + enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) + log_gpu_memory_usage("After critic FSDP", logger=None) critic_optimizer = optim.AdamW( @@ -959,7 +976,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -982,11 +999,11 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: @@ -1056,7 +1073,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -1070,7 +1087,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -1116,10 +1133,15 @@ def _build_model(self, config): trust_remote_code=trust_remote_code, ) + fused_kernel_options = config.model.get("fused_kernel_options", None) + fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + apply_monkey_patch( model=reward_module, use_remove_padding=config.model.get("use_remove_padding", False), ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=self.config.model.get("use_fused_kernels", False), + fused_kernels_backend=fused_kernels_backend, ) reward_module.to(torch.bfloat16) @@ -1135,7 +1157,7 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), @@ -1164,11 +1186,14 @@ def init_model(self): self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + elif is_npu_available: + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -1289,7 +1314,7 @@ def compute_rm_score(self, data: DataProto): from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: @@ -1304,7 +1329,7 @@ def compute_rm_score(self, data: DataProto): rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) + rm_data.batch = rm_data.batch.to(get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 8d89700d1a8..a0e806629e5 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -131,9 +131,11 @@ def __init__(self, config: DictConfig, role: str): self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False) self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False) elif self._is_ref: - if self.config.ref.get("ppo_micro_batch_size", None): + if self.config.ref.get("log_prob_micro_batch_size", None): self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + else: + assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time." self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config): @@ -203,6 +205,8 @@ def megatron_actor_model_provider(pre_process, post_process): return actor_module, actor_optimizer, self.hf_config, optim_config def _build_rollout(self, trust_remote_code=False): + from torch.distributed.device_mesh import init_device_mesh + layer_name_mapping = { "qkv_layer_name": "self_attention.linear_qkv.", "gate_proj_layer_name": "linear_fc1.weight", @@ -264,9 +268,20 @@ def _build_rollout(self, trust_remote_code=False): # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager + infer_tp = self.config.rollout.tensor_model_parallel_size + dp = self.world_size // infer_tp + assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")) + local_path = copy_to_local(self.config.model.path) log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) - rollout = SGLangRollout(actor_module=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config) + rollout = SGLangRollout( + actor_module=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + trust_remote_code=self.config.model.get("trust_remote_code", False), + ) log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None) from verl.models.mcore import get_mcore_weight_converter @@ -276,8 +291,50 @@ def _build_rollout(self, trust_remote_code=False): actor_module=self.actor.actor_module, inference_engine=rollout.inference_engine, model_config=self.actor_model_config, + transformer_config=self.tf_config, + layer_name_mapping=layer_name_mapping, + weight_converter=weight_converter, + device_mesh=rollout_device_mesh, + ) + log_gpu_memory_usage("After building sharding manager", logger=logger) + elif self.config.rollout.name == "sglang_async": + from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout + + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. + # However, due to verl's setting, the main process of ray can not find any CUDA device, which would potentially lead to: + # "RuntimeError: No CUDA GPUs are available". + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. + # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 + from verl.workers.sharding_manager.megatron_sglang import MegatronAsyncSGLangShardingManager + + infer_tp = self.config.rollout.tensor_model_parallel_size + dp = self.world_size // infer_tp + assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")) + + local_path = copy_to_local(self.config.model.path) + log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) + rollout = AsyncSGLangRollout( + actor_module=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + trust_remote_code=trust_remote_code, + device_mesh=rollout_device_mesh, + ) + log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None) + + from verl.models.mcore import get_mcore_weight_converter + + weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + sharding_manager = MegatronAsyncSGLangShardingManager( + actor_module=self.actor.actor_module, + inference_engine=rollout._engine, + model_config=self.actor_model_config, + transformer_config=self.tf_config, layer_name_mapping=layer_name_mapping, weight_converter=weight_converter, + device_mesh=rollout_device_mesh, ) log_gpu_memory_usage("After building sharding manager", logger=logger) else: @@ -434,7 +491,16 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage("After entering sharding manager", logger=logger) prompts = self.sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) + # output = self.rollout.generate_sequences(prompts=prompts) + if self.config.rollout.name == "sglang_async": + from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout + + if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0: + output = self.rollout.generate_sequences_with_tools(prompts=prompts) + else: + output = self.rollout.generate_sequences(prompts=prompts) + else: + output = self.rollout.generate_sequences(prompts=prompts) output = self.sharding_manager.postprocess_data(output) output = output.to("cpu") @@ -445,14 +511,16 @@ def generate_sequences(self, prompts: DataProto): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) def compute_ref_log_prob(self, data: DataProto): - data = data.to("cuda") assert self._is_ref if self._ref_is_offload_param: load_megatron_model_to_gpu(self.ref_module, load_grad=False) log_gpu_memory_usage("After load ref params and grad during compute_ref_log_prob", logger=logger) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature + data = data.to(torch.cuda.current_device()) output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = output.to("cpu") @@ -469,14 +537,14 @@ def compute_log_prob(self, data: DataProto): if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module, load_grad=False) log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger) - data = data.to("cuda") - output = data # we should always recompute old_log_probs when it is HybridEngine - output.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - output.meta_info["temperature"] = self.config.rollout.temperature - old_log_probs, entropys = self.actor.compute_log_prob(data=output, calculate_entropy=True) - output.batch["old_log_probs"] = old_log_probs - output.batch["entropys"] = entropys + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + data = data.to(torch.cuda.current_device()) + output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + output = DataProto.from_dict(tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}) output = output.to("cpu") # clear kv cache if self._is_offload_param: @@ -653,7 +721,11 @@ def init_model(self): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_values(self, data: DataProto): - data = data.to("cuda") + micro_batch_size = self.config.ppo_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + data = data.to(torch.cuda.current_device()) if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) values = self.critic.compute_values(data=data) @@ -665,7 +737,7 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def update_critic(self, data: DataProto): - data = data.to("cuda") + data = data.to(torch.cuda.current_device()) if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) @@ -837,7 +909,10 @@ def init_model(self): # the input_ids, responses, attention_mask and position_ids may be different! @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): - data.batch = data.batch.cuda() + data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + data = data.to(torch.cuda.current_device()) output = self.rm.compute_reward(data) output = output.to("cpu") return output diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index c320a42af77..399cdf05e09 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -17,7 +17,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score class DAPORewardManager: @@ -34,7 +34,7 @@ def __init__( ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key self.overlong_buffer_cfg = overlong_buffer_cfg self.max_resp_len = max_resp_len diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 3a59dc8b23a..59ad618c4c1 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -17,7 +17,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score class NaiveRewardManager: @@ -26,7 +26,7 @@ class NaiveRewardManager: def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key def __call__(self, data: DataProto, return_dict=False): diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index 1da30a252d3..f60e160e836 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -21,7 +21,7 @@ from transformers import PreTrainedTokenizer from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score import default_compute_score async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): @@ -108,7 +108,7 @@ def __init__( ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score + self.compute_score = compute_score or default_compute_score self.reward_fn_key = reward_fn_key def verify(self, data): diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py index 7948c213f0c..7ee3a47d514 100644 --- a/verl/workers/reward_model/megatron/reward_model.py +++ b/verl/workers/reward_model/megatron/reward_model.py @@ -15,6 +15,8 @@ Megatron Reward Model. """ +import itertools + import torch import torch.distributed from megatron.core import parallel_state as mpu @@ -23,7 +25,8 @@ from verl import DataProto from verl.utils.megatron.pipeline_parallel import make_batch_generator -from verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length, split_dict_tensor_into_batches +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length from verl.workers.reward_model.base import BasePPORewardModel @@ -128,15 +131,28 @@ def compute_reward(self, data: DataProto) -> DataProto: input_ids = data.batch["input_ids"] # (bs, seq_len') attention_mask = data.batch["attention_mask"] position_ids = data.batch["position_ids"] + use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) + micro_batch_size = data.meta_info.get("micro_batch_size", None) + max_token_len = data.meta_info.get("max_token_len", None) + assert micro_batch_size is not None, "micro batch size is needed for forward compute" + if use_dynamic_bsz: + assert max_token_len is not None, "use_dynamic_bsz is True, but max_token_len is None!" + max_token_len = max_token_len * self.config.megatron.context_parallel_size responses = data.batch["responses"] batch_size = responses.size(0) response_length = responses.size(1) with torch.no_grad(): - output = self.forward_batch(data) + output = self.forward_batch(data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len) if mpu.is_pipeline_last_stage(ignore_virtual=True): - logits = torch.cat(output, dim=0) + logits = torch.cat(output["output"], dim=0) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == logits.size(0), f"{len(indices)} vs. {logits.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + logits = logits[revert_indices] else: logits = torch.empty( (input_ids.shape[0], input_ids.shape[1]), @@ -184,7 +200,7 @@ def compute_reward(self, data: DataProto) -> DataProto: return DataProto(batch=batch) - def forward_batch(self, data: DataProto): + def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -192,19 +208,29 @@ def forward_batch(self, data: DataProto): """ # broadcast from last pp rank to all other pp ranks # TODO: actually, we just need to control the sampling order. - data.batch = data.batch.contiguous() - broadcast_dict_tensor(data.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) - - # split into micro-batches - if self.config is not None and "micro_batch_size_per_gpu" in self.config: - infer_batch_size = self.config.micro_batch_size_per_gpu + mini_batch = data + mini_batch.batch = mini_batch.batch.contiguous() + broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + + mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + + indices = None + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size is not None and vpp_size > 1: + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + else: + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + total_seqlen = max_token_len else: - infer_batch_size = data.batch.batch_size[0] - - data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) - batches = split_dict_tensor_into_batches(data.batch, batch_size=infer_batch_size) - n_micro_batch = len(batches) - seq_len = batches[0]["input_ids"].shape[1] + assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + micro_batches = mini_batch.batch.split(micro_batch_size) + seq_len = micro_batches[0]["input_ids"].shape[1] + total_seqlen = micro_batch_size * seq_len + n_micro_batch = len(micro_batches) # compute input shapes for pp stages forward_backward_func = get_forward_backward_func() @@ -233,7 +259,7 @@ def forward_step(batch_iter, model): return output, loss_func # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(batches, vpp_size=len(self.reward_model_module)) + batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module)) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) @@ -243,7 +269,7 @@ def forward_step(batch_iter, model): data_iterator=batch_generator, model=self.reward_model_module, num_microbatches=n_micro_batch, - seq_length=infer_batch_size * seq_len, # no use when input_shapes was set + seq_length=total_seqlen, # no use when input_shapes was set micro_batch_size=1, # no use when input_shapes was set forward_only=True, ) @@ -253,12 +279,14 @@ def forward_step(batch_iter, model): data_iterator=batch_generator, model=self.reward_model_module, num_microbatches=n_micro_batch, - seq_length=infer_batch_size * seq_len, # in use for pp = 1 + seq_length=total_seqlen, # in use for pp = 1 micro_batch_size=1, # in use for pp = 1 forward_only=True, ) # loss_reduces contains the stats returned from loss_func - + losses_reduced = {"output": losses_reduced} + if use_dynamic_bsz: + losses_reduced["indices"] = indices return losses_reduced def offload_params_to_cpu(self): diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py index f3c5d5cc448..a1fe27b7451 100644 --- a/verl/workers/rollout/async_server.py +++ b/verl/workers/rollout/async_server.py @@ -199,7 +199,8 @@ async def _chat_completions_openai(self, address: str, **chat_complete_request) async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: try: extra_headers = chat_complete_request.pop("extra_headers") - session = aiohttp.ClientSession() + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) async with session.post( url=f"http://{address}/v1/chat/completions", headers={"Authorization": "Bearer token-abc123", **extra_headers}, diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py index 7b04e51ef66..4cb6da68438 100644 --- a/verl/workers/rollout/hf_rollout.py +++ b/verl/workers/rollout/hf_rollout.py @@ -29,6 +29,7 @@ from verl import DataProto from verl.utils.torch_functional import get_response_mask +from verl.utils.device import get_torch_device from .base import BaseRollout @@ -109,6 +110,7 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: output = self.module.generate( input_ids=idx, attention_mask=attention_mask, + position_ids=position_ids, do_sample=do_sample, max_new_tokens=response_length, eos_token_id=eos_token_id, @@ -165,7 +167,7 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: ) # empty cache before compute old_log_prob - torch.cuda.empty_cache() + get_torch_device().empty_cache() self.module.train() return DataProto(batch=batch) diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py index f43cfe02405..145fa7cf143 100644 --- a/verl/workers/rollout/schemas.py +++ b/verl/workers/rollout/schemas.py @@ -103,12 +103,12 @@ class AsyncRolloutRequest(BaseModel): } } - def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> str: + def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]: return tokenizer.apply_chat_template( # type: ignore conversation=[msg.model_dump() for msg in self.messages], tools=[tool.model_dump() for tool in self.tools] if self.tools else None, add_generation_prompt=True, - tokenize=False, + tokenize=True, ) def add_assistant_message( diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index ada5eb3d2a5..b501305e1a0 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -365,7 +365,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): - print(f"{self.sampling_params=}") + # print(f"{self.sampling_params=}") if self._tp_rank == 0: loop = asyncio.get_event_loop() output = loop.run_until_complete( @@ -390,11 +390,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: out = _post_process_outputs(self.tokenizer, output) response = out[0].to(idx.device) - # log_probs = out[1].to(idx.device) + rollout_log_probs = out[1].to(idx.device) if response.shape[1] < self.config.response_length: response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id) # utilize current sampling params if self.sampling_params.get("n", 1) > 1 and do_sample: @@ -428,7 +428,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, @@ -482,7 +482,11 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo else: raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: - generation_prompt = _req.get_generation_prompt(self.tokenizer) + generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) + max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) + if max_new_tokens <= 0: + finish_reason_type = FinishReasonTypeEnum.STOP + break if not do_sample: kwargs = dict( n=1, @@ -494,7 +498,6 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo top_k=-1, ignore_eos=False, min_new_tokens=0, - max_new_tokens=self.config.response_length, skip_special_tokens=True, spaces_between_special_tokens=True, ) @@ -506,12 +509,13 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo "temperature": self.config.val_kwargs.temperature, "n": 1, # if validate, already repeat in ray_trainer } + kwargs["max_new_tokens"] = max_new_tokens if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess kwargs["n"] = 1 # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): output = await self._engine.async_generate( - prompt=generation_prompt, + input_ids=generation_prompt_ids, sampling_params=self.sampling_params, return_logprob=False, ) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index ed852f769f0..af30a568df5 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -307,7 +307,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): - print(f"{self.sampling_params=}") + # print(f"{self.sampling_params=}") output = self.inference_engine.generate( prompt=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 37a39a5ee82..06817b5d50f 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -229,11 +229,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # TODO(sgm): disable logprob when recompute_log_prob is enable # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) response = output[0].to(idx.device) - # log_probs = output[1].to(idx.device) + log_probs = output[1].to(idx.device) if response.shape[1] < self.config.response_length: response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) # utilize current sampling params if self.sampling_params.n > 1 and do_sample: @@ -262,7 +262,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index e8ae44437dd..e6a162ef791 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -282,11 +282,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) response = [] + rollout_log_probs = [] for output in outputs: for sample_id in range(len(output.outputs)): - response.append(output.outputs[sample_id].token_ids) + response_ids = output.outputs[sample_id].token_ids + response.append(response_ids) + curr_log_prob = [] + for i, logprob in enumerate(output.outputs[sample_id].logprobs): + curr_log_prob.append(logprob[response_ids[i]].logprob) + rollout_log_probs.append(curr_log_prob) response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device) + rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device) + rollout_log_probs = rollout_log_probs.to(torch.float32) if self.sampling_params.n > 1 and do_sample: idx = _repeat_interleave(idx, self.sampling_params.n) @@ -322,7 +330,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 114b9646fb2..133db24ef84 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -35,6 +35,7 @@ from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.torch_functional import check_cuda_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from verl.utils.device import get_torch_device from .base import BaseShardingManager @@ -84,26 +85,26 @@ def __init__( self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __enter__(self): - # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and + # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory # to speed up memory allocations. # # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - torch.cuda.empty_cache() + get_torch_device().empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) if self.offload_param: @@ -132,7 +133,7 @@ def __enter__(self): del params if self.offload_param: offload_fsdp_model_to_cpu(self.module) - torch.cuda.empty_cache() + get_torch_device().empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) @@ -141,8 +142,8 @@ def __enter__(self): # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): @@ -158,12 +159,12 @@ def __exit__(self, exc_type, exc_value, traceback): self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: @@ -194,6 +195,6 @@ def postprocess_data(self, data: DataProto) -> DataProto: def update_params(self, updated_params): model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) - device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy + device = get_torch_device().current_device() # used when fsdp2 set cpu_offload_policy loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items())) logger.info("vLLM load weights, loaded_params: %d", len(loaded_params)) diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index 817867a5a49..0a1352d9e74 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py @@ -1,4 +1,6 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +21,21 @@ import os import torch +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.entrypoints.verl_engine import VerlEngine from torch import nn +from torch.distributed.device_mesh import DeviceMesh + +from verl.protocol import DataProto, all_gather_data_proto +from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.megatron_utils import per_tensor_generator -from verl.utils.debug import log_gpu_memory_usage +from .base import BaseShardingManager logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) + + """ Megatron Hybrid Engine: - During training, only the current pp stage holds the parameters @@ -34,61 +45,142 @@ - After inference, all the parameters that doesn't belong to this pp rank is freed. """ -import torch.distributed -from sglang.srt.entrypoints.verl_engine import VerlEngine -from torch.distributed import new_group - -from verl.utils.debug import GPUMemoryLogger -from verl.utils.megatron_utils import per_tensor_generator - -from .base import BaseShardingManager - -_MICRO_DATA_PARALLEL_GROUP = None - class MegatronSGLangShardingManager(BaseShardingManager): - - def __init__(self, actor_module: nn.ModuleList, inference_engine: VerlEngine, model_config, layer_name_mapping, weight_converter): - from megatron.core import parallel_state as mpu + def __init__( + self, + actor_module: nn.ModuleList, + inference_engine: VerlEngine, + model_config, + transformer_config, + layer_name_mapping, + weight_converter, + device_mesh: DeviceMesh | None = None, + ): self.actor_module = actor_module self.inference_engine = inference_engine self.model_config = model_config + self.transformer_config = transformer_config self.layer_name_mapping = layer_name_mapping self.weight_converter = weight_converter - global _MICRO_DATA_PARALLEL_GROUP - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - self.infer_tp_size = self.inference_engine._tp_size - self.train_tp_size = mpu.get_tensor_model_parallel_world_size() - self.need_tp_reshard = self.infer_tp_size == self.train_tp_size - - assert self.infer_tp_size <= self.train_tp_size, \ - 'Not implemented for infer_tp > train_tp' - assert self.train_tp_size % self.infer_tp_size == 0 - - micro_dp_size = self.train_tp_size // self.infer_tp_size - num_micro_dp_groups = world_size // micro_dp_size - assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized") - for i in range(num_micro_dp_groups): - ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size) - group = new_group(ranks=ranks) - if rank in ranks: - _MICRO_DATA_PARALLEL_GROUP = group + self.device_mesh = device_mesh + + if self.device_mesh is not None: + self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0] + else: + self.infer_tp_size = self.inference_engine._tp_size + + # Note that torch_random_states may be different on each dp rank + self.torch_random_states = torch.cuda.get_rng_state() + # get a random rng states + if self.device_mesh is not None: + gen_dp_rank = self.device_mesh["dp"].get_local_rank() + torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) + else: + self.gen_random_states = None @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) def __enter__(self): - per_tensor_param = per_tensor_generator(self.actor_module, self.model_config, self.weight_converter, self.layer_name_mapping) - self.inference_engine.resume_memory_occupation() - self.inference_engine.update_weights_from_tensor(per_tensor_param, load_format=None) + per_tensor_param = per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) + self.update_weights(per_tensor_param) + + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger) - self.inference_engine.release_memory_occupation() - log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger) + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) for model in self.actor_module: model.train() # add empty cache after each compute torch.cuda.empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) + + def update_weights(self, params): + self.inference_engine.resume_memory_occupation() + self.inference_engine.update_weights_from_tensor(params, load_format=None) + + def release_memory(self): + self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) + def preprocess_data(self, data: DataProto) -> DataProto: + # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp + if self.infer_tp_size == 1: + return data + all_gather_data_proto(data, self.device_mesh["tp"].get_group()) + return data + + @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) + def postprocess_data(self, data: DataProto) -> DataProto: + # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp + if self.infer_tp_size == 1: + return data + return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()] + + +class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager): + def __init__( + self, + actor_module: nn.ModuleList, + inference_engine: Engine, + model_config, + transformer_config, + layer_name_mapping, + weight_converter, + device_mesh: DeviceMesh = None, + ): + super().__init__( + actor_module, + inference_engine, + model_config, + transformer_config, + layer_name_mapping, + weight_converter, + device_mesh, + ) + + def update_weights(self, params): + if self.device_mesh["tp"].get_local_rank() == 0: + self.inference_engine.resume_memory_occupation() + + # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update + # named_tensors = [(k, v) for k, v in params.items()] + named_tensors = params + load_format = None + for tensor_index, (name, tensor) in enumerate(named_tensors): + if self.device_mesh["tp"].get_local_rank() == 0: + self.inference_engine.update_weights_from_tensor( + named_tensors=[ + ( + name, + tensor.detach(), + ) + ], + load_format=load_format, + flush_cache=False, + ) + + if self.device_mesh["tp"].get_local_rank() == 0: + self.inference_engine.flush_cache() + + def release_memory(self): + if self.device_mesh["tp"].get_local_rank() == 0: + self.inference_engine.release_memory_occupation() diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index 3c276712326..a7568958c05 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -28,7 +28,6 @@ from torch import nn from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP -import verl.utils.megatron.tensor_parallel as tp_utils from verl import DataProto from verl.models.mcore.weight_converter import McoreToHFWeightConverterBase from verl.protocol import all_gather_data_proto @@ -36,10 +35,8 @@ from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import GPUMemoryLogger from verl.utils.megatron_utils import ( - broadcast_from_megatron_pp, - broadcast_str_from_megatron_pp, - convert_megatron_model_to_transformers_model, get_model, + per_tensor_generator, unwrap_model, ) from verl.utils.memory_buffer import ( @@ -47,7 +44,6 @@ build_memory_reference_from_module, get_weight_buffer_meta_from_module, ) -from verl.utils.model import normalize_model_name from verl.utils.torch_functional import check_cuda_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader @@ -308,218 +304,13 @@ def __init__( self.need_tp_reshard = self.train_tp_size != self.infer_tp_size self.train_tp_larger = self.train_tp_size > self.infer_tp_size - def per_tensor_generator(self, convert_qkv_gate_up_by_simple_split=True): - """ - convert_qkv_gate_up_by_simple_split is a parameter affected by the vLLM version. - """ - from megatron.core import parallel_state as mpu - - pp_rank = mpu.get_pipeline_model_parallel_rank() - vpp_size = len(self.actor_module) - - all_gather_group = self.train_tp_group - all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) - - def tensor_generator(): - for scan_vpp_idx in range(vpp_size): - yield from self.actor_module[scan_vpp_idx].named_parameters() - - # we need first make all rank get full model information - meta_info = [] - for scan_vpp_idx in range(vpp_size): - for idx, (name, _) in enumerate(self.actor_module[scan_vpp_idx].named_parameters()): - meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - - obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()) - layer_list_meta = [item for sublist in obj_spec_output for item in sublist] - - gen_func = tensor_generator() - - # lazy load tensor for full model - for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: - if self.model_config.tie_word_embeddings and ("output_layers" in name): - import warnings - - warnings.warn("Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2) - continue - if cur_pp_rank == pp_rank: - try: - cur_name, cur_tensor = next(gen_func) - except StopIteration: - cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, self.transformer_config) - else: - cur_tensor, cur_name = None, None - - # pp broadcast model tensor and name - cur_name = broadcast_str_from_megatron_pp(cur_name) - broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor) - - # (xya): this is a hack to fix the name of the parameters - while cur_name.startswith("module."): - cur_name = cur_name[len("module.") :] - - # EP - if ".mlp.experts.linear_fc" in cur_name and self.train_ep_size > 1: - num_experts = self.weight_converter.mcore_config.num_moe_experts - num_experts_per_rank = num_experts // self.train_ep_size - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(self.train_ep_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=self.train_ep_group) - - name_prefix, local_expert_id = cur_name.split(".weight") - local_expert_id = int(local_expert_id) - global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(self.train_ep_size)] - global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] - - for name, param in zip(global_expert_names, infer_params): - if self.train_etp_size > 1: - # gather etp - etp_params = [torch.empty_like(param) for _ in range(self.train_etp_size)] - torch.distributed.all_gather(etp_params, param, group=self.train_etp_group) - params = etp_params - else: - params = [param] - - merge_params = self.default_tp_concat_fn(name, broad_pp_tensor, params, self.model_config, convert_qkv_gate_up_by_simple_split) - if not isinstance(merge_params, list): - merge_params = [merge_params] - converted_names, converted_params = self.weight_converter.convert_param(name, merge_params) - - yield from zip(converted_names, converted_params) - continue - - # tp all gather - if tp_utils.is_tensor_parallel_param(broad_pp_tensor): - # allocate a new tensor with proper size - if all_gather_group_size <= 1: - infer_params = [broad_pp_tensor] - else: - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) - infer_params = self.default_tp_concat_fn(cur_name, broad_pp_tensor, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split) - else: - infer_params = broad_pp_tensor - - if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): - converted_names, converted_params = convert_megatron_model_to_transformers_model( - cur_name, - infer_params, - self.model_config, - self.train_tp_size, - 0, # no impact - convert_qkv_gate_up_by_trunk_concat=False, - ) # defualt false - else: - if not isinstance(infer_params, list): - infer_params = [infer_params] - converted_names, converted_params = self.weight_converter.convert_param(cur_name, infer_params) - - yield from zip(converted_names, converted_params) - - def default_tp_concat_fn(self, name, param, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False): - """ - name: name of the parameter - param: training parameters - infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered - from train tp group (vllm 0.8.2) or micro-dp group (vllm <= 0.6.3) - model_config: huggingface model_config - TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model - definition so that it is model-agnostic. If the model doesn't implement this function, - we can throw an error to force user disable TP HybridEngine. - """ - if self.layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 - num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" - kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - for infer_param in infer_params: - num_query_groups_per_partition = model_config.num_key_value_heads // self.train_tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] - - elif self.layer_name_mapping.get("gate_proj_layer_name") in name: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in infer_params: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] - - elif "mlp.experts.linear_fc2.weight" in name: # moe - infer_params = torch.cat(infer_params, dim=1) - - else: - # concat tensor - infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param)) - - return infer_params - - def _post_process_params(self, params, convert_qkv_gate_up_by_simple_split=False): - """ - For each param, if it is a tp-splited param, we all-gather from train tp group - """ - # here the params are in train tp format. we iterate params and all-gather - # TODO(zhangchi.usc1992) We can consider copy non-tp weight to another infer buffer. - # In this way, all the params in the original memory_buffers and can be offload. - all_gather_group = self.train_tp_group - all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) - - for name, param in params: - if tp_utils.is_tensor_parallel_param(param): - # allocate a new tensor with proper size - if all_gather_group_size <= 1: - infer_params = [param] - else: - infer_params = [torch.empty_like(param) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, param, group=all_gather_group) - infer_params = self.default_tp_concat_fn(name, param, infer_params, self.model_config, convert_qkv_gate_up_by_simple_split) - else: - infer_params = param - if vllm_version in ("0.4.2", "0.5.4", "0.6.3"): - converted_names, converted_params = convert_megatron_model_to_transformers_model( - name, - infer_params, - self.model_config, - self.train_tp_size, - self.module.pp_models[0][0].config.num_query_groups, - convert_qkv_gate_up_by_trunk_concat=False, - ) - else: - if not isinstance(infer_params, list): - infer_params = [infer_params] - converted_names, converted_params = self.weight_converter.convert_param(name, infer_params) - yield from zip(converted_names, converted_params) - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) def __enter__(self): if vllm_version in ( "0.5.4", "0.6.3", ): - per_tensor_param = self.per_tensor_generator(convert_qkv_gate_up_by_simple_split=False) + per_tensor_param = per_tensor_generator(self.actor_module, self.model_config, self.weight_converter, self.transformer_config, self.layer_name_mapping, convert_qkv_gate_up_by_simple_split=False) self.inference_engine.sync_model_weights(per_tensor_param, load_format="megatron") else: # > 0.7.2 @@ -527,7 +318,13 @@ def __enter__(self): self.inference_engine.wake_up(tags=["weights"]) else: self.inference_engine.wake_up() - per_tensor_param = self.per_tensor_generator() + per_tensor_param = per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model patch_vllm_moe_model_weight_loader(model) loaded_params = model.load_weights(per_tensor_param)