diff --git a/.github/workflows/checkpoint_converter.yml b/.github/workflows/checkpoint_converter.yml deleted file mode 100644 index 4820497f79c..00000000000 --- a/.github/workflows/checkpoint_converter.yml +++ /dev/null @@ -1,175 +0,0 @@ -# # Tests layout - -# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -# - `tests/trainer` for testing functionality related to `verl/trainer` -# - `tests/models` for testing functionality related to `verl/models` -# - ... - -# There are a few folders with `special_` prefix, created for special purposes: -# - `special_distributed`: unit tests that must run with multiple GPUs -# - `special_e2e`: end-to-end tests with training/generation scripts -# - `special_npu`: tests for NPUs -# - `special_sanity`: a suite of quick sanity tests -# - `special_standalone`: a set of test that are designed to run in dedicated environments - -# Accelerators for tests -# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. - -# # Workflow layout - -# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: -# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -# 3. End-to-end tests: `e2e_*.yml` -# 4. Unit tests -# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` -# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. 
-# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when -# - new workflow yaml is added to `.github/workflows` -# - new tests are added to workflow mentioned in 2. - -name: checkpoint_converter -# latest version: Megatron-LM core_v0.14.0 https://github.com/NVIDIA/Megatron-LM/tree/core_v0.14.0 - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Recipes - - "!recipe/**" - # FSDP - - "!verl/workers/**/*dp_*.py" - # Entrypoints - - ".github/workflows/checkpoint_converter.yml" - - ".github/workflows/e2e_ppo_trainer_megatron.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/special_e2e/run_ppo_trainer_megatron.sh" - - "verl/trainer/main_ppo.py" - - "verl/trainer/config/ppo_megatron_trainer.yaml" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. 
-permissions: - contents: read - -env: - IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:sgl055.dev2" - DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" - -jobs: - setup: - if: github.repository_owner == 'volcengine' - runs-on: ubuntu-latest - outputs: - runner-label: ${{ steps.create-runner.outputs.runner-label }} - mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} - steps: - - uses: actions/checkout@v4 - - id: create-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "create" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-image: "${{ env.IMAGE }}" - - checkpoint_converter: - needs: setup - runs-on: [ "${{ needs.setup.outputs.runner-label || 'L20x8' }}" ] - timeout-minutes: 20 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test] -# - name: Download Model to Use -# run: | -# huggingface-cli download Qwen/Qwen2.5-0.5B --local-dir ${HOME}/models/Qwen/Qwen2.5-0.5B -# huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct --local-dir ${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct -# export HF_HUB_OFFLINE=1 - - name: Running Huggingface to Megatron dist_ckpt converter (Qwen/Qwen2.5-0.5B) - run: | - ray stop --force - python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/Qwen/Qwen2.5-0.5B --test - - name: Running Huggingface to Megatron dist_ckpt converter (deepseek-ai/deepseek-coder-1.3b-instruct) - run: | - ray stop --force - python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct --output_path 
checkpoints/deepseek-ai/deepseek-coder-1.3b-instruct --test - - name: Clean up - run: | - rm -rf checkpoints - - checkpoint_converter_large_moe_models: - needs: setup - runs-on: [ "${{ needs.setup.outputs.runner-label || 'L20x8' }}" ] - timeout-minutes: 30 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - HF_ENDPOINT: "https://hf-mirror.com" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test] -# - name: Download Model to Use -# run: | -# huggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat --local-dir ${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat -# export HF_HUB_OFFLINE=1 - - name: Running Huggingface to Megatron dist_ckpt CPU converter (Qwen/Qwen1.5-MoE-A2.7B-Chat) - run: | - ray stop --force - python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat --use_cpu_initialization - - name: Running distributed Huggingface to Megatron dist_ckpt CPU converter (Qwen/Qwen1.5-MoE-A2.7B-Chat) - run: | - ray stop --force - torchrun --nproc_per_node 8 --nnodes 1 scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat_dist --use_cpu_initialization - - name: clean up - run: | - rm -rf checkpoints - - cleanup: - runs-on: ubuntu-latest - needs: - [ - setup, - checkpoint_converter, - checkpoint_converter_large_moe_models - ] - if: always() - steps: - - id: destroy-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "destroy" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" \ No newline at end of file diff --git 
a/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml b/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml index ccdc7c9c15d..4b9addffa2e 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml @@ -181,11 +181,6 @@ jobs: run: | ray stop --force ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 LR_WARMUP_STEPS=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) - run: | - exp_name="qwen3-0.6b-megatron-gsm8k-minimal" - python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with FP8 rollout run: | ray stop --force diff --git a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml index ccc503b0d58..9d0d92dfa97 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml @@ -125,13 +125,10 @@ jobs: - name: Prepare Geo3k dataset run: | python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/hiyouga/geometry3k/ - - name: Prepare dist_ckpt of Qwen2.5-VL-3B, only supports dist_ckpt - run: | - python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron - name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) run: | ray stop 
--force - ENGINE=sglang ROLLOUT_MODE=async TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=true DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/special_e2e/run_ppo_trainer_megatron.sh + ENGINE=sglang ROLLOUT_MODE=async TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=false DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up run: | rm -rf checkpoints diff --git a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml index 35a50aae0ea..deaaae67cf5 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml @@ -132,12 +132,12 @@ jobs: - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use mbridge e2e to pre-load and save (Deepseek) run: | ray stop --force - ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 USE_MBRIDGE=True USE_DIST_CKPT=False \ + ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 USE_DIST_CKPT=False \ bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use mbridge e2e to pre-load and save (Deepseek) run: | ray stop --force - RESUME_MODE=auto 
MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 USE_MBRIDGE=True USE_DIST_CKPT=False \ + RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 USE_DIST_CKPT=False \ bash tests/special_e2e/run_ppo_trainer_megatron.sh # LoRA training save&load - name: clean up and install Megatron-Bridge @@ -149,12 +149,12 @@ jobs: - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use Megatron-Bridge LoRA e2e to pre-load and save (Deepseek) run: | ray stop --force - ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=4 LORA_RANK=8 COMMON_VPP=null COMMON_CP=1 USE_MBRIDGE=True VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False USE_DIST_CKPT=False \ + ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=4 LORA_RANK=8 COMMON_VPP=null COMMON_CP=1 VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False USE_DIST_CKPT=False \ bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use Megatron-Bridge LoRA e2e to pre-load and save (Deepseek) run: | ray stop --force - RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 SAVE_FREQ=1 COMMON_PP=4 LORA_RANK=8 COMMON_VPP=null COMMON_CP=1 USE_MBRIDGE=True VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False USE_DIST_CKPT=False \ + RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 SAVE_FREQ=1 COMMON_PP=4 LORA_RANK=8 COMMON_VPP=null COMMON_CP=1 VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False USE_DIST_CKPT=False \ bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up run: | @@ -186,11 +186,6 @@ jobs: run: | ray stop --force ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 LR_WARMUP_STEPS=1 TOTAL_TRAIN_STEPS=2 
MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) - run: | - exp_name="qwen3-0.6b-megatron-gsm8k-minimal" - python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with FP8 rollout run: | ray stop --force diff --git a/.github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml b/.github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml index fb3e73ed02d..03e646400f8 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml @@ -134,7 +134,7 @@ jobs: ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \ PPO_MAX_TOKEN_LEN=1024 FWD_MAX_TOKEN_LEN=1024 \ MAX_PROMPT_LENGTH=512 MAX_RESPONSE_LENGTH=512 \ - MODEL_ID=Qwen/Qwen3-30B-A3B-Instruct-2507 USE_MBRIDGE=True VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False \ + MODEL_ID=Qwen/Qwen3-30B-A3B-Instruct-2507 VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False \ COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=4 COMMON_ETP=1 INFER_TP=8 \ USE_DIST_CKPT=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up @@ -146,7 +146,7 @@ jobs: ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \ PPO_MAX_TOKEN_LEN=1024 FWD_MAX_TOKEN_LEN=1024 \ MAX_PROMPT_LENGTH=512 MAX_RESPONSE_LENGTH=512 LORA_RANK=8 
CRITIC_LORA_RANK=8 \ - MODEL_ID=Qwen/Qwen3-30B-A3B-Instruct-2507 USE_MBRIDGE=True VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False \ + MODEL_ID=Qwen/Qwen3-30B-A3B-Instruct-2507 VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False \ COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=2 COMMON_ETP=1 INFER_TP=8 \ USE_DIST_CKPT=False ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up @@ -174,16 +174,13 @@ jobs: - name: Prepare Geo3k dataset run: | python3 examples/data_preprocess/geo3k.py --local_dataset_path ${HOME}/models/hf_data/hiyouga/geometry3k/ - - name: Prepare dist_ckpt of Qwen2.5-VL-3B, only supports dist_ckpt - run: | - python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron - name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) run: | ray stop --force TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet \ MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo \ USE_DYNAMIC_BSZ=False USE_FUSED_KERNELS=True SKIP_SAVE_HF_MODEL=1 \ - COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=true \ + COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=false \ DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up run: | diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml index cab35a68d96..c9f1f2deac2 100644 --- a/.github/workflows/model.yml +++ b/.github/workflows/model.yml @@ -48,7 +48,6 @@ on: # Entrypoints - ".github/workflows/model.yml" - "tests/special_distributed/test_fsdp_ckpt.py" - - "tests/special_distributed/test_mcore_config_converter.py" - "tests/special_distributed/test_tensor_dict.py" - "tests/models/**" - "tests/special_distributed/run_all.sh" @@ 
-144,34 +143,6 @@ jobs: run: | STRATEGY=fsdp2 torchrun --nproc_per_node=8 tests/special_distributed/test_fsdp_ckpt.py - mcore_config_converter: - needs: setup - runs-on: [ "${{ needs.setup.outputs.runner-label || 'L20x8' }}" ] - timeout-minutes: 20 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test] -# - name: Download model config files -# run: | -# hf download Qwen/Qwen2.5-7B config.json --local-dir $HOME/configs/Qwen/Qwen2.5-7B -# hf download Qwen/Qwen3-8B config.json --local-dir $HOME/configs/Qwen/Qwen3-8B -# hf download deepseek-ai/deepseek-coder-1.3b-instruct config.json --local-dir $HOME/configs/deepseek-ai/deepseek-coder-1.3b-instruct -# hf download Qwen/Qwen2-57B-A14B config.json --local-dir $HOME/configs/Qwen/Qwen2-57B-A14B -# hf download Qwen/Qwen3-30B-A3B config.json --local-dir $HOME/configs/Qwen/Qwen3-30B-A3B -# hf download deepseek-ai/DeepSeek-V3-Base config.json --local-dir $HOME/configs/deepseek-ai/DeepSeek-V3-Base - - name: Running mcore config converter tests on 8 L20 GPUs - run: | - torchrun --nproc_per_node=8 tests/special_distributed/test_mcore_config_converter.py model_engine: needs: setup @@ -206,7 +177,6 @@ jobs: setup, model_rmpad, model_rmpad_fsdp2_unstable, - mcore_config_converter, model_engine ] if: always() diff --git a/docs/advance/checkpoint.rst b/docs/advance/checkpoint.rst index 56bec4a75c3..9782af951d9 100644 --- a/docs/advance/checkpoint.rst +++ b/docs/advance/checkpoint.rst @@ -137,32 +137,8 @@ Current implementation use solution 2. 
HuggingFace to Megatron DistCheckpoint details ---------------------------------------------- -If your model is quite huge, we recommend you to use Megatron dist-checkpoint to load the model. -Megatron dist-checkpoint supports loading with different kinds of model parallelism, -and it is much faster than the original checkpoint loading. - -To convert original HuggingFace model to Megatron dist-checkpoint, -you can use the ``scripts/converter_hf_to_mcore.py`` script. Large MoE models are temporarily supported with CPU initialization, -which is a little slower. While we are working on a better solution to support large models. - -Example command to convert the model is as follows: - -.. code:: bash - - python scripts/converter_hf_to_mcore.py \ - --hf_model_path Qwen/Qwen1.5-MoE-A2.7B-Chat \ - --output_path /mnt/disk/Qwen/Qwen1.5-MoE-A2.7B-Chat \ - --use_cpu_initialization # Only work for MoE models - - -Example command to distributed convert the huge model like deepseekv3 671B is as follows: - -.. code:: bash - - torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \ - --hf_model_path deepseek-ai/DeepSeek-V3 \ - --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \ - --use_cpu_initialization # Only work for MoE models +Through ``mbridge``, we can directly save the mcore model to huggingface format during training. +No need to convert the model to Megatron dist-checkpoint format. Original Checkpoint Utils ------------------------- diff --git a/docs/perf/best_practices.rst b/docs/perf/best_practices.rst index d7ff382c250..83452c06977 100644 --- a/docs/perf/best_practices.rst +++ b/docs/perf/best_practices.rst @@ -108,8 +108,6 @@ Parameter Reference :math:`\theta` - ``actor_rollout_ref.model.path``: Path to the actor checkpoint in HuggingFace-compatible format. - - ``actor_rollout_ref.actor.megatron.use_mbridge``: - Enable mbridge format conversion when the model was trained with Megatron. 
Use the latest mbridge release: https://github.com/ISEEKYAN/mbridge. :math:`\pi` - ``actor_rollout_ref.rollout.name``: diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh index ae55caf9d30..e1b90da0eff 100644 --- a/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh +++ b/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh @@ -112,7 +112,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ trainer.default_local_dir=$CKPT_DIR \ trainer.val_before_train=False \ trainer.total_epochs=100 $@ diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh index ede8eeda79f..85548fb5a9f 100644 --- a/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh +++ b/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh @@ -91,7 +91,6 @@ python3 -m verl.trainer.main_ppo \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh index 5344cfd9aa6..6c0f5ba0b68 100644 --- a/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh +++ b/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh @@ -50,7 +50,6 @@ ACTOR=( 
actor_rollout_ref.actor.ppo_mini_batch_size=16 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 actor_rollout_ref.actor.use_dynamic_bsz=True - actor_rollout_ref.actor.megatron.use_mbridge=True actor_rollout_ref.actor.megatron.vanilla_mbridge=False actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh index 632bdc8fa1e..16f025073c3 100644 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh @@ -5,8 +5,6 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation ov HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct DIST_CKPT_PATH=${DIST_CKPT_PATH} -# convert HF model to meagatron format offlinely -# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH # megatron tuning guide: diff --git a/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh b/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh index 0d3b855b6a9..b94c9998765 100644 --- a/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh +++ b/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh @@ -93,7 +93,6 @@ python3 -m verl.trainer.main_ppo \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ diff --git a/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh b/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh index 7dfc197f214..5a2080216b5 100644 --- a/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh +++ b/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh @@ -53,7 +53,6 @@ python3 -m 
verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.megatron.param_offload=True \ actor_rollout_ref.actor.megatron.optimizer_offload=True \ actor_rollout_ref.actor.megatron.grad_offload=True \ diff --git a/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh b/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh index 4c5b2de24f7..f6f2af85358 100644 --- a/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh +++ b/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh @@ -53,7 +53,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.megatron.param_offload=True \ actor_rollout_ref.actor.megatron.optimizer_offload=True \ actor_rollout_ref.actor.megatron.grad_offload=True \ diff --git a/examples/grpo_trainer/run_qwen3_vl-8b-megatron.sh b/examples/grpo_trainer/run_qwen3_vl-8b-megatron.sh index 69739c2d512..c3ae05df0cc 100644 --- a/examples/grpo_trainer/run_qwen3_vl-8b-megatron.sh +++ b/examples/grpo_trainer/run_qwen3_vl-8b-megatron.sh @@ -57,7 +57,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.megatron.param_offload=True \ actor_rollout_ref.actor.megatron.optimizer_offload=True \ actor_rollout_ref.actor.megatron.grad_offload=True \ diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh index 1db311e28f2..d0acf53fd65 100644 --- 
a/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh +++ b/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh @@ -89,7 +89,6 @@ RM_ETP=${RM_ETP:-$COMMON_ETP} # install mbridge # pip3 install git+https://github.com/ISEEKYAN/mbridge -USE_MBRIDGE=True USE_DIST_CKPT=False python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ @@ -124,7 +123,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ - actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ actor_rollout_ref.actor.megatron.param_offload=${offload} \ actor_rollout_ref.actor.megatron.grad_offload=${offload} \ diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh index 77805cdfb3b..8b20349e115 100644 --- a/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh +++ b/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh @@ -50,7 +50,6 @@ ACTOR=( actor_rollout_ref.actor.optim.lr=3e-6 actor_rollout_ref.actor.ppo_mini_batch_size=16 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 - actor_rollout_ref.actor.megatron.use_mbridge=True actor_rollout_ref.actor.megatron.vanilla_mbridge=False actor_rollout_ref.actor.use_dynamic_bsz=True actor_rollout_ref.actor.use_kl_loss=True diff --git a/examples/gspo_trainer/run_qwen30b_gspo.sh b/examples/gspo_trainer/run_qwen30b_gspo.sh index f4cb3309be6..7568219f840 100644 --- a/examples/gspo_trainer/run_qwen30b_gspo.sh +++ b/examples/gspo_trainer/run_qwen30b_gspo.sh @@ -82,8 +82,7 @@ ACTOR_MEGATRON_CONFIG=" 
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ - +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ - actor_rollout_ref.actor.megatron.use_mbridge=True" + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True" # Actor model config ACTOR_CONFIG=" diff --git a/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh b/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh index f7d232bb6b2..d3d204136c3 100644 --- a/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh +++ b/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh @@ -141,7 +141,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ actor_rollout_ref.ref.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ diff --git a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh index 9e1d40576f0..b25c7651073 100644 --- a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh @@ -9,8 +9,6 @@ huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct # 1. 
convert the model to mcore format # change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct -DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct -python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH # 2. run the script @@ -95,12 +93,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \ critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \ critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - critic.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ trainer.val_before_train=False \ trainer.total_epochs=100 $@ \ No newline at end of file diff --git a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh index b82ea1d4373..ba4f6c71ad5 100644 --- a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh @@ -6,10 +6,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation ov #huggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat # 1. convert the model to mcore format -# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path HF_MODEL_PATH=/data/models/Qwen/Qwen1.5-MoE-A2.7B-Chat -DIST_CKPT_PATH=/data/mcore_ckpt/Qwen1.5-MoE-A2.7B-Chat -python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH # 2. 
run the script gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -41,13 +38,9 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ actor_rollout_ref.actor.megatron.context_parallel_size=$CP \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ actor_rollout_ref.ref.megatron.context_parallel_size=$CP \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ @@ -58,8 +51,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.megatron.tensor_model_parallel_size=$TP \ critic.megatron.pipeline_model_parallel_size=$PP \ critic.megatron.context_parallel_size=$CP \ - critic.megatron.use_dist_checkpointing=True \ - critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ diff --git a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh index 9a45bdaac73..2f119ec4b3d 100644 --- a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh +++ b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh @@ -81,7 +81,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_INFER_TP \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.mode=async \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ 
actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ diff --git a/examples/sft/vlm/run_qwen3_vl_2b.sh b/examples/sft/vlm/run_qwen3_vl_2b.sh index 28c21ffa049..d6be0c354e2 100644 --- a/examples/sft/vlm/run_qwen3_vl_2b.sh +++ b/examples/sft/vlm/run_qwen3_vl_2b.sh @@ -60,7 +60,6 @@ MEGATRON_ENGINE_CONFIG="\ engine.pipeline_model_parallel_size=${PP_SIZE} \ engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ engine.context_parallel_size=${CP_SIZE} \ - engine.use_mbridge=True \ engine.vanilla_mbridge=True" if [ "$backend" = "fsdp" ]; then diff --git a/recipe/dapo/test_dapo_8b_megatron_fp16.sh b/recipe/dapo/test_dapo_8b_megatron_fp16.sh index 0dfd77854cb..d9b78e06d65 100644 --- a/recipe/dapo/test_dapo_8b_megatron_fp16.sh +++ b/recipe/dapo/test_dapo_8b_megatron_fp16.sh @@ -121,7 +121,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.rollout.val_kwargs.do_sample=True \ actor_rollout_ref.rollout.val_kwargs.n=1 \ actor_rollout_ref.rollout.calculate_log_probs=True \ diff --git a/recipe/dapo/test_dapo_8b_megatron_fp8train.sh b/recipe/dapo/test_dapo_8b_megatron_fp8train.sh index 5827abdd879..67be3f92172 100644 --- a/recipe/dapo/test_dapo_8b_megatron_fp8train.sh +++ b/recipe/dapo/test_dapo_8b_megatron_fp8train.sh @@ -128,7 +128,6 @@ ACTOR=( actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} actor_rollout_ref.actor.entropy_coeff=0 actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} - actor_rollout_ref.actor.megatron.use_mbridge=True ) ROLLOUT=( diff --git a/recipe/dapo/test_dapo_glm_air_megatron.sh b/recipe/dapo/test_dapo_glm_air_megatron.sh index 2e7d91c07a5..64c45fcc6d1 100644 --- 
a/recipe/dapo/test_dapo_glm_air_megatron.sh +++ b/recipe/dapo/test_dapo_glm_air_megatron.sh @@ -91,7 +91,6 @@ RM_TP=${RM_TP:-$TRAIN_TP} RM_EP=${RM_EP:-$COMMON_EP} RM_ETP=${RM_ETP:-$COMMON_ETP} -USE_MBRIDGE=True USE_DIST_CKPT=False # Install the latest mbridge @@ -125,7 +124,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ actor_rollout_ref.actor.optim.lr_decay_style='constant' \ actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ actor_rollout_ref.actor.megatron.param_offload=${offload} \ actor_rollout_ref.actor.megatron.grad_offload=${offload} \ diff --git a/recipe/dapo/test_dapo_gptoss_20b_megatron.sh b/recipe/dapo/test_dapo_gptoss_20b_megatron.sh index 850ba26259e..c16a1964483 100644 --- a/recipe/dapo/test_dapo_gptoss_20b_megatron.sh +++ b/recipe/dapo/test_dapo_gptoss_20b_megatron.sh @@ -46,8 +46,6 @@ python get_model.py ####################### specific training config: ####################### GPT_OSS_CONFIG=( - # only support mbridge for gptoss - actor_rollout_ref.actor.megatron.use_mbridge=True # for now (latest TE=2.10), gptoss's optimized attn kernel is not supported for thd format, so we use bshd format here # when bshd format is used, we need to pad the input_ids to the longest sequence length # so we recommend to disable dynamic batch size and set micro batch size to 1 to avoid paddings diff --git a/recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh b/recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh index f5c85ca22b2..eaa769d6658 100644 --- a/recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh +++ b/recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh @@ -129,7 +129,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ 
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.rollout.val_kwargs.do_sample=True \ actor_rollout_ref.rollout.val_kwargs.n=1 \ actor_rollout_ref.rollout.calculate_log_probs=True \ diff --git a/recipe/fully_async_policy/megatron_worker.py b/recipe/fully_async_policy/megatron_worker.py index fc948ce2ea8..342d2fe267e 100644 --- a/recipe/fully_async_policy/megatron_worker.py +++ b/recipe/fully_async_policy/megatron_worker.py @@ -27,7 +27,7 @@ get_device_name, get_torch_device, ) -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker logger = logging.getLogger(__file__) @@ -113,17 +113,7 @@ def clear_cpu_model(self, n): class DetachActorWorker(DetachNcclSync): def _get_actor_params_generator(self): assert self._is_actor - if self.bridge is not None: - generator = self.bridge.export_weights(self.actor.actor_module) - else: - generator = per_tensor_generator( - self.actor.actor_module, - self.actor_model_config, - self.weight_converter, - self.tf_config, - self.layer_name_mapping, - ) - + generator = self.bridge.export_weights(self.actor.actor_module) return generator @register(dispatch_mode=Dispatch.ONE_TO_ALL) diff --git a/recipe/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh b/recipe/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh index 5baee7917e5..e62aba98f3d 100644 --- a/recipe/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh +++ b/recipe/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh @@ -82,7 +82,6 @@ python -m recipe.fully_async_policy.fully_async_main \ +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ 
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ - actor_rollout_ref.actor.megatron.use_mbridge=True \ actor_rollout_ref.actor.megatron.param_offload=True \ actor_rollout_ref.actor.megatron.optimizer_offload=True \ actor_rollout_ref.actor.megatron.grad_offload=True \ diff --git a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh index be9523f9e08..449dca1e9ee 100644 --- a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh +++ b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh @@ -91,7 +91,6 @@ RM_ETP=${RM_ETP:-$COMMON_ETP} # install mbridge # pip3 install git+https://github.com/ISEEKYAN/mbridge -USE_MBRIDGE=True USE_DIST_CKPT=False # Fully async specific parameters @@ -146,7 +145,6 @@ python -m recipe.fully_async_policy.fully_async_main \ +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ - actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ actor_rollout_ref.actor.megatron.param_offload=${offload} \ actor_rollout_ref.actor.megatron.grad_offload=${offload} \ diff --git a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh index ba2d6e4680b..2c51d4152dd 100644 --- a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh +++ b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh @@ -91,7 +91,6 @@ RM_ETP=${RM_ETP:-$COMMON_ETP} # install mbridge # pip3 
install git+https://github.com/ISEEKYAN/mbridge -USE_MBRIDGE=True USE_DIST_CKPT=False # Fully async specific parameters @@ -160,7 +159,6 @@ python -m recipe.fully_async_policy.fully_async_main \ +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ - actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ actor_rollout_ref.actor.megatron.param_offload=${offload} \ actor_rollout_ref.actor.megatron.grad_offload=${offload} \ diff --git a/recipe/gkd/config/on_policy_distill_trainer.yaml b/recipe/gkd/config/on_policy_distill_trainer.yaml index 120c3b7a38b..79a96612344 100644 --- a/recipe/gkd/config/on_policy_distill_trainer.yaml +++ b/recipe/gkd/config/on_policy_distill_trainer.yaml @@ -77,7 +77,6 @@ actor_rollout_ref: seed: 42 # additional transformer config like: num_layers_in_first(/last)_pipeline_stage override_transformer_config: {} - use_mbridge: False optim: # Learning rate lr: 1e-6 diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py index c2a2407939e..0ca42f59496 100644 --- a/recipe/one_step_off_policy/megatron_workers.py +++ b/recipe/one_step_off_policy/megatron_workers.py @@ -123,21 +123,7 @@ class DetachActorWorker(DetachSync): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def _get_actor_params_generator(self): assert self._is_actor - from verl.models.mcore import get_mcore_weight_converter - from verl.utils.megatron_utils import per_tensor_generator - - layer_name_mapping = { - "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.", - } - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - generator = per_tensor_generator( - 
self.actor.actor_module, - self.actor_model_config, - weight_converter, - self.tf_config, - layer_name_mapping, - ) + generator = self.bridge.export_weights(self.actor.actor_module) return generator @register(dispatch_mode=Dispatch.ONE_TO_ALL) diff --git a/recipe/open_math_reasoning/run_sft_qwen3_8b.sh b/recipe/open_math_reasoning/run_sft_qwen3_8b.sh index 3b7e9bb5c6c..b5169f12fa3 100644 --- a/recipe/open_math_reasoning/run_sft_qwen3_8b.sh +++ b/recipe/open_math_reasoning/run_sft_qwen3_8b.sh @@ -54,8 +54,7 @@ MEGATRON_ENGINE_CONFIG="\ engine.tensor_model_parallel_size=${TP_SIZE} \ engine.pipeline_model_parallel_size=${PP_SIZE} \ engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ - engine.context_parallel_size=${CP_SIZE} \ - engine.use_mbridge=False" + engine.context_parallel_size=${CP_SIZE}" if [ "$backend" = "fsdp" ]; then ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" diff --git a/scripts/converter_hf_to_mcore.py b/scripts/converter_hf_to_mcore.py deleted file mode 100644 index 6e7cdf2b5ab..00000000000 --- a/scripts/converter_hf_to_mcore.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -import argparse -import os -import warnings -from contextlib import contextmanager -from importlib.metadata import version -from typing import Any, Callable, ContextManager, Optional - -import numpy as np -import torch -import torch.distributed as dist - -try: - # NPU patch - import mindspeed.megatron_adaptor # noqa: F401 - from mindspeed.megatron_adaptor import repatch -except ImportError: - repatch = None - pass - -from accelerate import init_empty_weights -from megatron.core import dist_checkpointing -from megatron.core import parallel_state as mpu -from megatron.core.dist_checkpointing.mapping import ShardedTensor -from megatron.core.dist_checkpointing.serialization import StrictHandling -from megatron.core.models.gpt.gpt_model import ModelType -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from packaging.version import Version -from transformers import AutoConfig - -from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards -from verl.models.mcore import hf_to_mcore_config -from verl.utils.device import get_device_name, get_torch_device -from verl.utils.megatron_utils import get_model - - -def _init_args(): - """ - Examples: - - 1. single rank conversion for any model: - > python converter_hf_to_mcore.py --hf_model_path %{hf_model} --output_path ${output_path} - 2. 
distributed conversion for DeepseekV3 671B: - > torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \ - --hf_model_path %{hf_model} --output_path ${output_path} - """ - parser = argparse.ArgumentParser() - parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") - parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") - parser.add_argument("--pp_size", type=int, default=1, help="pipeline model parallel size") - parser.add_argument("--ep_size", type=int, default=1, help="expert model parallel size") - parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization") - parser.add_argument("--test", action="store_true", help="Whether to test the conversion") - parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code") - args = parser.parse_args() - return args - - -def test_conversion(megatron_model_provider, tfconfig, output_path, model): - ########### test ########### - # load model - model_test = get_model( - model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - transformer_config=tfconfig, - ) - ref_state_dict = model_test[0].module.sharded_state_dict() - dist_checkpointing.load(ref_state_dict, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED) - - dut_state_dict = model[0].module.state_dict() - for name in dut_state_dict.keys(): - if dut_state_dict[name] is None: - print(f"[Warning] {name} is none in dut_state_dict") - continue - dut_data = dut_state_dict[name].data - if name in ref_state_dict: - ref_data = ref_state_dict[name] - if isinstance(ref_data, ShardedTensor): - ref_data = ref_data.data.view(ref_data.local_shape) - else: - ref_data = ref_data.data - assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" - assert (dut_data == ref_data).all(), 
f"{name} is not equal" - print(f"{name} is equal") - else: - print(f"[Warning] {name} is not in ref_state_dict") - for name in ref_state_dict.keys(): - if ref_state_dict[name] is None: - print(f"[Warning] {name} is none in ref_state_dict") - continue - ref_data = ref_state_dict[name] - if isinstance(ref_data, ShardedTensor): - ref_data = ref_data.data.view(ref_data.local_shape) - else: - ref_data = ref_data.data - if name in dut_state_dict: - dut_data = dut_state_dict[name].data - assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" - assert (dut_data == ref_data).all(), f"{name} is not equal" - print(f"{name} is equal") - else: - print(f"[Warning] {name} is not in dut_state_dict") - print("Conversion test passed!") - - -@torch.inference_mode() -def convert_checkpoint_from_transformers_to_megatron( - hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None -): - if layer_start_end is None: - layer_start_end = (0, len(model.decoder.layers)) - layer_start, layer_end = layer_start_end - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - ep_rank = mpu.get_expert_model_parallel_rank() - ep_size = mpu.get_expert_model_parallel_world_size() - numel = 0 - - num_attention_heads = hf_config.num_attention_heads - num_key_value_heads = hf_config.num_key_value_heads - hidden_dim = hf_config.hidden_size - head_dim = getattr(hf_config, "head_dim", hidden_dim // num_attention_heads) - if num_attention_heads != num_key_value_heads: - print("[WARNING] Converting GQA model") - has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr(hf_config, "attention_bias", False) - has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None) - if pp_rank == 0: - numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) - - assert len(model.decoder.layers) == (layer_end - layer_start), ( - f"Expected 
{len(model.decoder.layers)} layers, but got {layer_end - layer_start}" - ) - for layer_idx, (layer, hf_layer) in enumerate( - zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True) - ): - global_layer_idx = layer_idx + layer_start - numel_cur = numel - numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight) - - q = hf_layer.self_attn.q_proj.weight.view( - [num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1] - ) - k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1]) - v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1]) - qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous() - numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight) - - if has_qkv_bias: - q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1]) - k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1]) - v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1]) - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous() - numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias) - - if hasattr(hf_layer.self_attn, "q_norm"): - numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight) - numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight) - - numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight) - numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight) - - numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) - - for idx, hf_expert in enumerate(hf_layer.mlp.experts): - num_experts = len(hf_layer.mlp.experts) - num_local_experts = num_experts // ep_size - expert_idx_start = ep_rank * num_local_experts - expert_idx_end = (ep_rank + 1) * num_local_experts - if idx < expert_idx_start 
or idx >= expert_idx_end: - continue - local_expert_idx = idx - expert_idx_start - - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{local_expert_idx}"]) - numel += safe_copy( - hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{local_expert_idx}"] - ) - - if has_share_expert: - numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight) - shared_fc1_weight = torch.cat( - [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight] - ) - numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) - numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) - print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}") - - if pp_rank == pp_size - 1: - numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) - numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight) - return numel - - -def safe_copy( - src_tensor: torch.Tensor, - dst_tensor: torch.Tensor, - skip_dtype_assert: bool = False, -): - if not skip_dtype_assert: - if src_tensor.dtype != dst_tensor.dtype: - raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") - assert src_tensor.shape == dst_tensor.shape - dst_tensor.data.copy_(src_tensor.data) - return src_tensor.numel() - - -@torch.inference_mode() -def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config): - mgmodel = mgmodel.bfloat16() - hfmodel = hfmodel.bfloat16() - num_attention_heads = hf_config.num_attention_heads - num_query_groups = hf_config.num_key_value_heads - hidden_size = hf_config.hidden_size - head_dim = hidden_size // num_attention_heads - - # 1. 
vision model - if Version(version("transformers")) < Version("4.52.0"): - print("Using transformers < 4.52 API to load vision model") - hfvision = hfmodel.visual - else: - hfvision = hfmodel.model.visual - mgvision = mgmodel.vision_model - vision_hidden_size = mgvision.config.hidden_size - vision_num_query_groups = mgvision.config.num_query_groups - vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads - copied_numel = 0 - safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq) - copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight) - for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True): - # norm1 --> linear_qkv.norm - copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight) - # norm2 --> mlp.linear_fc1.norm - copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight) - # qkv --> self_attention.linear_qkv - converted_weight = ( - hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size) - .transpose(0, 1) - .flatten(1, 2) - .reshape(-1, vision_hidden_size) - .contiguous() - ) - copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight) - converted_bias = ( - hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1) - .transpose(0, 1) - .flatten(1, 2) - .view(-1) - .contiguous() - ) - copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias) - # proj --> self_attention.linear_proj - copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight) - copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias) - # mlp --> mlp: gate - fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight]) - fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias]) - copied_numel += safe_copy(fc1_weight, 
mgblock.mlp.linear_fc1.weight) - copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias) - copied_numel += safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight) - copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias) - - # 2. vision projector - hfprojector = hfvision.merger - mgprojector = mgvision.projection - copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight) - - copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight) - copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias) - copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight) - copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias) - n_params = sum([t.numel() for t in hfvision.state_dict().values()]) - assert n_params == copied_numel, f"n_params={n_params} != copied_numel={copied_numel}" - # 3. llm [just Qwen2] - if Version(version("transformers")) < Version("4.52.0"): - print("Using transformers < 4.52 API to load llm") - hfllm = hfmodel.model - else: - hfllm = hfmodel.model.language_model - mgllm = mgmodel.language_model - copied_numel = 0 - copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight) - layermaps = zip(mgllm.decoder.layers, hfllm.layers, strict=True) - for mglayer, hflayer in layermaps: - copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight) - - q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous() - copied_numel += 
safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight) - - q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1) - k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1) - v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1) - qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous() - copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias) - copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight) - - fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight]) - copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight) - - copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight) - copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight) - - copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight) - if not hf_config.tie_word_embeddings: - safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight) - - n_params = sum([t.numel() for t in hfllm.state_dict().values()]) - - assert n_params == copied_numel, f"n_params={n_params} != copied_numel={copied_numel}" - - -@torch.inference_mode() -def convert_checkpoint_from_transformers_to_megatron_dpskv3( - hf_model, - model, - hf_config, - tfconfig, - layer_start_end: Optional[tuple[int, int]] = None, -): - warnings.warn("MTP model is not supported yet", stacklevel=2) - if layer_start_end is None: - layer_start_end = (0, len(model.decoder.layers)) - layer_start, layer_end = layer_start_end - numel: int = 0 - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - ep_rank = mpu.get_expert_model_parallel_rank() - ep_size = mpu.get_expert_model_parallel_world_size() - - if pp_rank == 0: - numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) - 
- assert len(model.decoder.layers) == (layer_end - layer_start), ( - f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}" - ) - for layer_idx, (layer, hf_layer) in enumerate( - zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True) - ): - global_layer_idx = layer_idx + layer_start - numel_cur: int = numel - numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight) - - if hf_config.q_lora_rank is None: - numel += safe_copy(hf_layer.self_attn.q_proj.weight, layer.self_attention.linear_q_proj.weight) - else: - numel += safe_copy(hf_layer.self_attn.q_a_proj.weight, layer.self_attention.linear_q_down_proj.weight) - numel += safe_copy(hf_layer.self_attn.q_b_proj.weight, layer.self_attention.linear_q_up_proj.weight) - numel += safe_copy( - hf_layer.self_attn.q_a_layernorm.weight, layer.self_attention.linear_q_up_proj.layer_norm_weight - ) - - numel += safe_copy( - hf_layer.self_attn.kv_a_proj_with_mqa.weight, layer.self_attention.linear_kv_down_proj.weight - ) - numel += safe_copy(hf_layer.self_attn.kv_b_proj.weight, layer.self_attention.linear_kv_up_proj.weight) - numel += safe_copy( - hf_layer.self_attn.kv_a_layernorm.weight, layer.self_attention.linear_kv_up_proj.layer_norm_weight - ) - numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight) - - if not hasattr(layer.mlp, "router"): - numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight) - numel += safe_copy( - torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), layer.mlp.linear_fc1.weight - ) - numel += safe_copy(hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight) - else: - numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) - # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \ - # recover to fp32 in the first forward. 
There is always a diff in the bias between two models (~0.3%) - numel += safe_copy( - hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True - ) - if tfconfig.moe_grouped_gemm: - for i, hf_expert in enumerate(hf_layer.mlp.experts): - num_experts = len(hf_layer.mlp.experts) - num_local_experts = num_experts // ep_size - expert_idx_start = ep_rank * num_local_experts - expert_idx_end = (ep_rank + 1) * num_local_experts - if i < expert_idx_start or i >= expert_idx_end: - continue - local_expert_idx = i - expert_idx_start - - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(local_expert_idx)) - numel += safe_copy(fc1_weight, linear_fc1_weighti) - linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(local_expert_idx)) - numel_w2 = safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti) - numel += numel_w2 - else: - for i, hf_expert in enumerate(hf_layer.mlp.experts): - expert = layer.mlp.experts.local_experts[i] - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - numel += safe_copy(fc1_weight, expert.linear_fc1.weight) - numel += safe_copy(hf_expert.down_proj.weight, expert.linear_fc2.weight) - numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight) - shared_fc1_weight = torch.cat( - [hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight] - ) - numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) - numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) - print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}") - numel_hf_one_layer = sum([i.numel() for i in hf_layer.state_dict().values()]) - if hasattr(layer.mlp, "router"): - numel_hf_one_layer -= numel_w2 * 3 * 
len(hf_layer.mlp.experts) // ep_size * (ep_size - 1) - assert numel - numel_cur == numel_hf_one_layer, "numel mismatch" - - if pp_rank == pp_size - 1: - numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) - if not hf_config.tie_word_embeddings: - numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight) - print(f"{pp_rank=} {numel=}") - return numel - - -@contextmanager -def noop_context() -> Any: - yield - - -def support_distributed_convert(hf_config: AutoConfig) -> bool: - for arch in ["DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM", "Qwen2MoeForCausalLM"]: - if arch in hf_config.architectures: - return True - return False - - -def convert_hf_to_mcore( - hf_model_path, output_path, pp_size=1, ep_size=1, use_cpu_initialization=False, test=False, trust_remote_code=False -): - os.makedirs(output_path, exist_ok=True) - if len(os.listdir(output_path)) > 0 and not test: - print(f"Output path {output_path} is not empty, skipping conversion") - return - - # init torch distributed and mpu - if "WORLD_SIZE" not in os.environ: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - torch.distributed.init_process_group("nccl") - - local_rank = os.getenv("LOCAL_RANK", 0) - world_size = dist.get_world_size() - get_torch_device().set_device(f"{get_device_name()}:{local_rank}") - if ep_size * pp_size != world_size: - pp_size = world_size - print(f"pp_size is set to {pp_size}") - - mpu.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=pp_size, - virtual_pipeline_model_parallel_size=None, - context_parallel_size=1, - expert_model_parallel_size=ep_size, - ) - model_parallel_cuda_manual_seed(0) - - # init hf config - hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=trust_remote_code) - print(hf_config, flush=True) - - if repatch: - if hf_config.architectures[0] == "DeepseekV3ForCausalLM": - 
config_repatch = dict(multi_head_latent_attention=True) - repatch(config_repatch) - - if world_size > 1 and not support_distributed_convert(hf_config): - raise NotImplementedError(f"distributed conversion is not supported for {hf_config.architectures} yet.") - - pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, pp_size) - print(f"Pipeline shards: {pipeline_shards}", flush=True) - - tfconfig = hf_to_mcore_config( - hf_config, - torch.bfloat16, - num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None, - num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None, - ) - tfconfig.use_cpu_initialization = use_cpu_initialization - tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False) - - # init megatron model - def megatron_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - tfconfig, - hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=tie_word_embeddings, - value=False, - ) - return parallel_model - - context: Callable[..., ContextManager] = init_empty_weights if use_cpu_initialization else noop_context - with context(): - model = get_model( - model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - transformer_config=tfconfig, - ) - - if use_cpu_initialization: - # convert meta device to empty tensor so it can use `copy_` function - model[0].module = model[0].module.to_empty(device="cpu") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from transformers import AutoModelForCausalLM, AutoModelForImageTextToText - - # init hf model - if "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: - hf_model = AutoModelForImageTextToText.from_pretrained( - hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code - ) - else: - hf_model = 
AutoModelForCausalLM.from_pretrained( - hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code - ) - hf_state_dict = hf_model.state_dict() - - pp_rank = mpu.get_pipeline_model_parallel_rank() - - # distributed convert - if world_size > 1 and support_distributed_convert(hf_config): - pipeline_cumsum = np.cumsum(pipeline_shards) - layer_start = 0 if pp_rank == 0 else pipeline_cumsum[pp_rank - 1] - layer_end = pipeline_cumsum[pp_rank] - if "DeepseekV3ForCausalLM" in hf_config.architectures: - numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3( - hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end) - ) - elif "Qwen3MoeForCausalLM" in hf_config.architectures or "Qwen2MoeForCausalLM" in hf_config.architectures: - numel_partial: int = convert_checkpoint_from_transformers_to_megatron( - hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end) - ) - else: - raise NotImplementedError(f"Distributed conversion is not supported for {hf_config.architectures} yet.") - - numel_tensor = torch.tensor([numel_partial]).to(get_device_name()) - dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM) - numel = int(numel_tensor.cpu().item()) - print(f"total numel={numel} vs {hf_model.num_parameters()=}") - if numel != hf_model.num_parameters(): - warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1) - - # load hf state dict to megatron model - elif "Qwen2MoeForCausalLM" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) - elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config) - elif "DeepseekV3ForCausalLM" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig) - elif 
"Qwen3MoeForCausalLM" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) - else: - assert not use_cpu_initialization, "use_cpu_initialization is only supported for MoE model" - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - - load_state_dict_to_megatron_gptmodel( - state_dict=hf_state_dict, - wrapped_models=model, - config=hf_config, - params_dtype=torch.bfloat16, - is_value_model=False, - ) - - megatron_state_dict = model[0].module.sharded_state_dict() - del hf_state_dict, hf_model - - # save megatron model - if len(os.listdir(output_path)) == 0: - dist_checkpointing.save(megatron_state_dict, output_path, sharded_strategy=None, async_sharded_save=False) - if test: - test_conversion(megatron_model_provider, tfconfig, output_path, model) - - -if __name__ == "__main__": - args = _init_args() - convert_hf_to_mcore( - args.hf_model_path, - args.output_path, - args.pp_size, - args.ep_size, - args.use_cpu_initialization, - args.test, - args.trust_remote_code, - ) diff --git a/scripts/legacy_model_merger.py b/scripts/legacy_model_merger.py deleted file mode 100644 index a6da5072df0..00000000000 --- a/scripts/legacy_model_merger.py +++ /dev/null @@ -1,804 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. 
- -To merge FSDP checkpoints: -```sh -python scripts/legacy_model_merger.py merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -To merge Megatron checkpoints: -```sh -python scripts/legacy_model_merger.py merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -For more details, please refer to documentation: -https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model -""" - -import argparse -import os -import re -import warnings -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional, Union - -import numpy as np -import torch -from accelerate import init_empty_weights -from safetensors.torch import load_file -from torch.distributed._tensor import Placement, Shard -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForTokenClassification, - AutoModelForVision2Seq, - GenerationConfig, - PretrainedConfig, -) - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from tqdm import tqdm - -from verl.utils import hf_processor, hf_tokenizer - - -@dataclass -class ModelMergerConfig: - operation: str # 'merge' or 'test' - backend: str - local_dir: str - hf_model_config_path: str - target_dir: Optional[str] = "tmp" - hf_upload_path: Optional[str] = None - private: bool = False - test_hf_dir: Optional[str] = None - tie_word_embedding: bool = False - is_value_model: bool = False - hf_model_path: Optional[str] = None - hf_upload: bool = field(init=False) - - def __post_init__(self): - self.hf_upload 
= self.operation == "merge" and bool(self.hf_upload_path) - if self.operation == "test": - self.target_dir = None - self.hf_upload_path = None - self.private = False - - -class BaseModelMerger(ABC): - def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = config.hf_model_config_path - - if config.hf_model_path: - print( - "Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. " - ) - self.hf_model_config_path = config.hf_model_path - - # Auto-detect huggingface subdirectory if it exists - huggingface_subdir = os.path.join(self.hf_model_config_path, "huggingface") - if os.path.isdir(huggingface_subdir): - self.hf_model_config_path = huggingface_subdir - - self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) - - def get_transformers_auto_model_class(self): - # Handle case where architectures might be None or empty - if self.model_config.architectures is None or len(self.model_config.architectures) == 0: - # Try to infer from model_type if architectures is missing - model_type = getattr(self.model_config, 'model_type', '').lower() - if 'vision' in model_type or 'vl' in model_type: - return AutoModelForVision2Seq - elif 'causal' in model_type or 'gpt' in model_type or 'llama' in model_type or 'qwen' in model_type: - return AutoModelForCausalLM - else: - raise NotImplementedError( - f"Cannot determine model class: architectures is None and model_type '{model_type}' is not recognized" - ) - - architecture = self.model_config.architectures[0] - if "ForTokenClassification" in architecture: - return AutoModelForTokenClassification - elif "ForCausalLM" in architecture: - return AutoModelForCausalLM - elif "ForConditionalGeneration" in architecture: - return AutoModelForVision2Seq - - raise NotImplementedError(f"Unknown architecture 
{self.model_config.architectures}") - - def patch_model_generation_config(self, model): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) - except OSError: - print( - f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config." - ) - return model - - def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): - """ - Save lora adapter to safetensors. - - Returns: - lora_path: str, the path to the lora adapter. None if no lora adapter found. - - Note: - This function change the 'state_dict' in place. - """ - lora_params_names = [name for name in state_dict.keys() if "lora_" in name] - - if len(lora_params_names) == 0: - return None - - import json - from typing import OrderedDict - - import peft - from safetensors.torch import save_file - - lora_params = OrderedDict() - target_modules = set() - lora_key = None - - for name in lora_params_names: - lora_key = name.replace(".default.weight", ".weight") - target_modules.add(lora_key.split(".")[-3]) - lora_params[lora_key] = state_dict.pop(name) - - lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) - peft_dict = { - "r": lora_rank, - "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. 
- "target_modules": list(target_modules), - } - peft_config = peft.LoraConfig(**peft_dict).to_dict() - peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None - peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None - peft_config["target_modules"] = list(peft_config["target_modules"]) - - lora_path = os.path.join(self.config.target_dir, "lora_adapter") - os.makedirs(lora_path, exist_ok=True) - with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: - json.dump(peft_config, f, ensure_ascii=False, indent=4) - save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) - - for name in list(state_dict.keys()): - key = ( - name.replace("base_model.model.", "") - .replace(".base_layer.weight", ".weight") - .replace(".base_layer.bias", ".bias") - ) - state_dict[key] = state_dict.pop(name) - - return lora_path - - def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = self.patch_model_generation_config(model) - - lora_path = self.save_lora_adapter(state_dict) - if lora_path: - print(f"Saving lora adapter to {lora_path}") - - print(f"Saving model to {self.config.target_dir}") - model.save_pretrained(self.config.target_dir, state_dict=state_dict) - del state_dict - del model - - processor = hf_processor(self.hf_model_config_path) - try: - tokenizer = hf_tokenizer(self.hf_model_config_path) - except Exception as e: - warnings.warn(f"Failed to create tokenizer: {e}. 
This may affect tokenizer saving", stacklevel=1) - tokenizer = None - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def upload_to_huggingface(self): - from huggingface_hub import HfApi - - api = HfApi() - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") - - @abstractmethod - def merge_and_save(self): - raise NotImplementedError("Subclasses should implement this method") - - -class FSDPModelMerger(BaseModelMerger): - def _get_world_size(self) -> int: - """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" - for filename in os.listdir(self.config.local_dir): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - return int(match.group(1)) - raise FileNotFoundError( - f"Could not determine world size. No file matching 'model_world_size_(\\d+)_rank_0.pt' found in {self.config.local_dir}" - ) - - def _load_rank_zero_state_dict(self, world_size: int) -> dict: - return torch.load( - Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", - map_location="cpu", - weights_only=False, - ) - - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: - """ - Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. - If no DTensor is found, infers a simple FSDP mesh based on world_size. 
- """ - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([world_size], dtype=np.int64) - mesh_dim_names = ("fsdp",) - - return mesh, mesh_dim_names - - def _calculate_shard_configuration( - self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] - ) -> tuple[int, tuple[int, ...]]: - """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - - if "tp" in mesh_dim_names: - # TODO: "tp" is not supported yet due to the above assert - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) - - return total_shards, mesh_shape - - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: - """Merges a list of tensors based on their DTensor placement""" - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - - raise NotImplementedError(f"Unsupported placement: {placement}") - - def _load_and_merge_state_dicts( - self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] 
- ) -> dict[str, torch.Tensor]: - model_state_dict_lst = [None] * total_shards - - def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): - future.result() - - # Merge state dicts from all shards - state_dict = {} - param_placements: dict[str, list] = {} - - for key in set(model_state_dict_lst[0].keys()): - state_dict[key] = [] - for model_state_shard in model_state_dict_lst: - # add tensor shard in order of rank to state_dict[key] - tensor = model_state_shard.pop(key) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] in ("dp", "ddp"): - placements = placements[1:] - - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - # Merge tensors - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = self._merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - else: - 
state_dict[key] = torch.cat(state_dict[key], dim=0) - - return state_dict - - def merge_and_save(self): - world_size = self._get_world_size() - rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") - - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) - hf_state_dict = hf_model.state_dict() - del hf_model - - hf_model_keys = set(hf_state_dict.keys()) - collected_keys = set(state_dict.keys()) - - missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in hf_model_keys: - hf_shape = hf_state_dict[key].shape - collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, ( - f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - ) - - hf_dtype = 
hf_state_dict[key].dtype - collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, ( - f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - ) - - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - - -class MegatronModelMerger(BaseModelMerger): - def __init__(self, config: ModelMergerConfig): - from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path - - config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) - super().__init__(config) - - self.params_mapping = { - # megatron core gpt model name, huggingface model name - # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the longer key within the containing relationship is processed first. - "embedding.word_embeddings": "model.embed_tokens", - # attn - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", - "self_attention.linear_qkv": "self_attn.qkv_proj", - "self_attention.q_layernorm": "self_attn.q_norm", - "self_attention.k_layernorm": "self_attn.k_norm", - "self_attention.linear_proj": "self_attn.o_proj", - # mla - "self_attention.linear_q_proj": "self_attn.q_proj", - "self_attention.linear_q_down_proj": "self_attn.q_a_proj", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - "self_attention.linear_q_up_proj": "self_attn.q_b_proj", - "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", - # mlp - "pre_mlp_layernorm": "post_attention_layernorm", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - 
"mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", - "mlp.linear_fc1": "mlp.gate_up_proj", - "mlp.linear_fc2": "mlp.down_proj", - # moe - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - "mlp.router": "mlp.gate", - "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", - "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", - "linear_fc1": "gate_up_proj", - "linear_fc2": "down_proj", - # output - "final_layernorm": "norm", - "output_layer": "lm_head", - } - - def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: - tp_rank = pp_rank = None - rank_list = sharded_dir.split("_")[2:] - if re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir): - tp_rank = int(rank_list[0]) - pp_rank = int(rank_list[1]) - elif re.match(r"mp_rank_(\d\d)", sharded_dir): - tp_rank = int(rank_list[0]) - pp_rank = 0 - - assert tp_rank is not None and pp_rank is not None, f"Invalid sharded dir {sharded_dir}" - - return tp_rank, pp_rank - - def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: - """ - Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). - Determines TP and PP sizes from directory names. 
- """ - tp_size = 0 - pp_size = 0 - sharded_dirs = sorted(os.listdir(model_path)) - for sharded_dir in sharded_dirs: - assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - tp_size = max(tp_size, tp_rank + 1) - pp_size = max(pp_size, pp_rank + 1) - return sharded_dirs, tp_size, pp_size - - def _merge_across_tp( - self, - key: str, - tp_data: list[torch.Tensor], - config: PretrainedConfig, - tp_size: int, - is_value_model: bool = False, - ) -> Union[torch.Tensor, list[torch.Tensor]]: - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in tp_data: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. 
- q_lst = [] - k_lst = [] - v_lst = [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 - kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - - for infer_param in tp_data: - num_query_groups_per_partition = config.num_key_value_heads // tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - return [q, k, v] - elif "layer_norm" in key or "layernorm" in key or "router" in key or ("output_layer" in key and is_value_model): - return tp_data[0] - else: - dim = 0 - if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: - dim = 1 - return torch.cat(tp_data, dim=dim) - - def _load_state_dicts( - self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int - ) -> list[list[dict]]: - model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)] - - def _process_one_megatron_shard(sharded_dir: str): - model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" - state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - model_state_dict_lst[pp_rank][tp_rank] = state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] - for future in tqdm(futures, desc=f"Loading 
{len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): - future.result() - - return model_state_dict_lst - - def _check_megatron_state_key(self, key: str) -> bool: - """ - Checks if the key is a valid Megatron state key. - - Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. - Shall not use key starts with "model." - """ - if key.startswith("model."): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer." - ) - - skip_checking_keys = ["embedding.word_embeddings", "output_layer"] - for skip_key in skip_checking_keys: - if skip_key in key: - print(f"skip checking key {key}") - return - - # Exclude extra state keys - if not key.startswith("decoder"): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." - ) - - def _merge_state_dicts( - self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int - ) -> dict[str, torch.Tensor]: - state_dict = {} - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - self._check_megatron_state_key(key) - hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." - if "model.layers." 
in hf_name: - local_layer_no = int(hf_name.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = hf_name.split(".") - new_key_list[2] = str(global_layer_no) - hf_name = ".".join(new_key_list) - else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model) - - if not isinstance(merged, list): - state_dict[hf_name] = merged - elif len(merged) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], merged): - state_dict[hf_name.replace("qkv", n)] = d - elif len(merged) == 2: - # split gate up - state_dict[hf_name.replace("gate_up", "gate")] = merged[0] - state_dict[hf_name.replace("gate_up", "up")] = merged[1] - print( - f"converted {key} to {hf_name} with shape {merged.shape if isinstance(merged, torch.Tensor) else [t.shape for t in merged]}" - ) - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def merge_and_save(self): - from verl.utils.megatron_utils import get_model_checkpoint_path - - model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) - sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") - - model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) - merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) - del model_state_dict_lst - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - 
self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. - """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - for name, loaded_weight in state_dict.items(): - # name = self._replace_name(original_name, self.params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embedding and "lm_head.weight" in name: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: - for m_name, v_name in name_mapping.items(): - if m_name not in megatron_name: - continue - - megatron_name = megatron_name.replace("decoder", "model") - param_name = megatron_name.replace(m_name, v_name) - return param_name - - return None # Return None if no mapping found - - -def main(): - parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") - - base_op_parser = argparse.ArgumentParser(add_help=False) - base_op_parser.add_argument( - "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" - ) - base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") - base_op_parser.add_argument( - 
"--hf_model_path", - type=str, - default=None, - help="(Deprecated) Path to the original Hugging Face model for config.", - ) - base_op_parser.add_argument( - "--tie-word-embedding", - action="store_true", - help="Whether to tie word embedding weights (currently only Megatron supported)", - ) - base_op_parser.add_argument( - "--is-value-model", - action="store_true", - help="Whether the model is a value model (currently only Megatron supported)", - ) - - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") - merge_parser.add_argument( - "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" - ) - merge_parser.add_argument( - "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" - ) - merge_parser.add_argument( - "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" - ) - - test_parser = subparsers.add_parser( - "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" - ) - test_parser.add_argument( - "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" - ) - - args = parser.parse_args() - - common_config_args = { - "operation": args.operation, - "backend": args.backend, - "tie_word_embedding": args.tie_word_embedding, - "is_value_model": args.is_value_model, - "local_dir": args.local_dir, - "hf_model_path": args.hf_model_path, - "hf_model_config_path": args.local_dir, - } - - if args.operation == "merge": - config = ModelMergerConfig( - **common_config_args, - target_dir=args.target_dir, - hf_upload_path=args.hf_upload_path, - private=args.private, - test_hf_dir=None, - ) - os.makedirs(config.target_dir, exist_ok=True) - elif args.operation == "test": - config = ModelMergerConfig( - **common_config_args, - test_hf_dir=args.test_hf_dir, - # the following args are not used by 
test operation - target_dir=None, - hf_upload_path=None, - private=False, - ) - else: - raise NotImplementedError(f"Unknown operation: {args.operation}") - - if config.backend == "fsdp": - merger = FSDPModelMerger(config) - elif config.backend == "megatron": - merger = MegatronModelMerger(config) - else: - raise NotImplementedError(f"Unknown backend: {config.backend}") - - merger.merge_and_save() - - -if __name__ == "__main__": - main() diff --git a/tests/models/test_engine.py b/tests/models/test_engine.py index bb3433a3f33..43d1ed507f2 100644 --- a/tests/models/test_engine.py +++ b/tests/models/test_engine.py @@ -73,7 +73,6 @@ def test_engine(strategy): if strategy == "megatron": engine_config = McoreEngineConfig( forward_only=False, - use_mbridge=False, tensor_model_parallel_size=2, pipeline_model_parallel_size=2, context_parallel_size=2, @@ -230,7 +229,6 @@ def test_critic_engine(strategy): if strategy == "megatron": engine_config = McoreEngineConfig( forward_only=False, - use_mbridge=False, tensor_model_parallel_size=2, pipeline_model_parallel_size=2, context_parallel_size=2, @@ -353,7 +351,6 @@ def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, mod if strategy == "megatron": engine_config = McoreEngineConfig( forward_only=False, - use_mbridge=True, tensor_model_parallel_size=2, pipeline_model_parallel_size=2, context_parallel_size=1, diff --git a/tests/special_distributed/test_mcore_config_converter.py b/tests/special_distributed/test_mcore_config_converter.py deleted file mode 100644 index d8f24c49911..00000000000 --- a/tests/special_distributed/test_mcore_config_converter.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import megatron.core.parallel_state as mpu -import torch -from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from transformers import AutoConfig, PretrainedConfig - -from verl.models.mcore import hf_to_mcore_config -from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group - -TEST_MODELS = [ - "Qwen/Qwen2.5-7B", # Qwen2 dense - "Qwen/Qwen3-8B", # Qwen3 dense - "deepseek-ai/deepseek-coder-1.3b-instruct", # deepseek dense - "Qwen/Qwen2-57B-A14B", # Qwen2 moe - "Qwen/Qwen3-30B-A3B", # Qwen3 moe - # "mistralai/Mixtral-8x7B-v0.1", # Mixtral # require authentication - "deepseek-ai/DeepSeek-V3-Base", # Deepseek V3 -] - - -def check_config_converter_results(tf_config: TransformerConfig | MLATransformerConfig, hf_config: PretrainedConfig): - assert tf_config.num_layers == hf_config.num_hidden_layers, ( - f"Number of layers mismatch: {tf_config.num_layers} != {hf_config.num_hidden_layers}" - ) - assert tf_config.hidden_size == hf_config.hidden_size, ( - f"Hidden size mismatch: {tf_config.hidden_size} != {hf_config.hidden_size}" - ) - assert tf_config.num_attention_heads == hf_config.num_attention_heads, ( - f"Number of attention heads mismatch: {tf_config.num_attention_heads} != {hf_config.num_attention_heads}" - ) - assert tf_config.num_query_groups == hf_config.num_key_value_heads, ( - f"Number of query groups mismatch: {tf_config.num_query_groups} != {hf_config.num_key_value_heads}" - ) - assert tf_config.ffn_hidden_size == hf_config.intermediate_size, ( - f"FFN hidden size 
mismatch: {tf_config.ffn_hidden_size} != {hf_config.intermediate_size}" - ) - assert tf_config.attention_dropout == hf_config.attention_dropout, ( - f"Attention dropout mismatch: {tf_config.attention_dropout} != {hf_config.attention_dropout}" - ) - assert tf_config.hidden_dropout == getattr(hf_config, "hidden_dropout", 0.0), ( - f"Hidden dropout mismatch: {tf_config.hidden_dropout} != {getattr(hf_config, 'hidden_dropout', 0.0)}" - ) - if getattr(hf_config, "head_dim", None) is not None: - assert tf_config.kv_channels == getattr(hf_config, "head_dim", None), ( - f"Head dim mismatch: {tf_config.kv_channels} != {getattr(hf_config, 'head_dim', None)}" - ) - assert tf_config.layernorm_epsilon == hf_config.rms_norm_eps, ( - f"Layernorm epsilon mismatch: {tf_config.layernorm_epsilon} != {hf_config.rms_norm_eps}" - ) - - -def modify_hf_config(name: str, hf_config: PretrainedConfig): - if name == "deepseek-ai/DeepSeek-V3-Base": - hf_config.num_nextn_predict_layers = 0 - hf_config.quantization_config = None - return hf_config - - -def test_mcore_config_converter(): - """ - Test the conversion of Hugging Face model configurations to MCore configurations. 
- """ - local_rank, rank, world_size = initialize_global_process_group() - mpu.initialize_model_parallel( - tensor_model_parallel_size=2, - pipeline_model_parallel_size=2, - virtual_pipeline_model_parallel_size=None, - use_sharp=False, - context_parallel_size=2, - expert_model_parallel_size=1, - expert_tensor_parallel_size=None, - nccl_communicator_config_path=None, - ) - for model_name in TEST_MODELS: - print(f"testing {model_name}") - hf_config = AutoConfig.from_pretrained(os.path.expanduser(f"~/models/configs/{model_name}/config.json")) - hf_config = modify_hf_config(model_name, hf_config) - tf_config = hf_to_mcore_config(hf_config, torch.bfloat16) - check_config_converter_results(tf_config, hf_config) - - destroy_global_process_group() - - -if __name__ == "__main__": - test_mcore_config_converter() diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh index a88500aba40..dca6cd320b4 100644 --- a/tests/special_e2e/run_ppo_trainer_megatron.sh +++ b/tests/special_e2e/run_ppo_trainer_megatron.sh @@ -107,7 +107,6 @@ CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -USE_MBRIDGE=${USE_MBRIDGE:-False} VANILLA_MBRIDGE=${VANILLA_MBRIDGE:-True} VALUE_VANILLA_MBRIDGE=${VALUE_VANILLA_MBRIDGE:-$VANILLA_MBRIDGE} USE_FUSED_KERNELS=${USE_FUSED_KERNELS:-False} @@ -126,9 +125,6 @@ if [ "$USE_DIST_CKPT" = "True" ]; then if [ "$USE_DUMMY_MODEL" = "True" ]; then DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID} fi - python scripts/converter_hf_to_mcore.py \ - --hf_model_path "${MODEL_PATH}" \ - --output_path "${DIST_CKPT_PATH}" fi ENGINE=${ENGINE:-"vllm"} @@ -175,7 +171,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ 
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \ - actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \ actor_rollout_ref.actor.megatron.vanilla_mbridge=${VANILLA_MBRIDGE} \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ @@ -204,7 +199,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ ++actor_rollout_ref.rollout.quantization=${ROLLOUT_QUANTIZATION} \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \ actor_rollout_ref.ref.megatron.vanilla_mbridge=${VANILLA_MBRIDGE} \ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ @@ -226,7 +220,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ critic.model.lora.target_modules=${LORA_TARGET_MODULES} \ critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \ - critic.megatron.use_mbridge=${USE_MBRIDGE} \ critic.megatron.vanilla_mbridge=${VALUE_VANILLA_MBRIDGE} \ critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ @@ -246,7 +239,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ reward_model.enable=True \ reward_model.model.path="${MODEL_PATH}" \ reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - reward_model.megatron.use_mbridge=${USE_MBRIDGE} \ reward_model.megatron.vanilla_mbridge=${VALUE_VANILLA_MBRIDGE} \ 
reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ diff --git a/tests/special_npu/run_qwen2_5_05b_grpo_mindspeed.sh b/tests/special_npu/run_qwen2_5_05b_grpo_mindspeed.sh index bdf225dc3a1..f00e661205f 100644 --- a/tests/special_npu/run_qwen2_5_05b_grpo_mindspeed.sh +++ b/tests/special_npu/run_qwen2_5_05b_grpo_mindspeed.sh @@ -10,9 +10,6 @@ if [ "$USE_DIST_CKPT" = "True" ]; then if [ "$USE_DUMMY_MODEL" = "True" ]; then DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID} fi - python scripts/converter_hf_to_mcore.py \ - --hf_model_path "${MODEL_PATH}" \ - --output_path "${DIST_CKPT_PATH}" fi diff --git a/tests/special_npu/run_qwen3_30b_dapo_mindspeed.sh b/tests/special_npu/run_qwen3_30b_dapo_mindspeed.sh index aece3d11471..346ce37a38a 100644 --- a/tests/special_npu/run_qwen3_30b_dapo_mindspeed.sh +++ b/tests/special_npu/run_qwen3_30b_dapo_mindspeed.sh @@ -42,9 +42,6 @@ if [[ "$USE_DIST_CKPT" == "True" ]]; then fi fi - torchrun --nproc_per_node 2 --nnodes 1 scripts/converter_hf_to_mcore.py \ - --hf_model_path "${MODEL_PATH}" \ - --output_path "${DIST_CKPT_PATH}" fi exp_name='Qwen3-30B-A3B-DAPO-MindSpeed' diff --git a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml index ea2a15d685e..5376a2915d7 100644 --- a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml +++ b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml @@ -111,7 +111,6 @@ actor_rollout_ref: dist_checkpointing_path: null seed: 42 override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage - use_mbridge: False vanilla_mbridge: True profile: # profile the actor model in `update_policy` use_profile: False # open it when you want to profile the actor model @@ -144,7 +143,6 @@ actor_rollout_ref: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: 
${actor_rollout_ref.actor.megatron.override_transformer_config} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge} profile: use_profile: False @@ -315,7 +313,6 @@ critic: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge} load_weight: True ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} @@ -360,7 +357,6 @@ reward_model: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: {} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge} model: input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical diff --git a/tests/utils/megatron/test_pipeline_parallel.py b/tests/utils/megatron/test_pipeline_parallel.py deleted file mode 100644 index 24a416987da..00000000000 --- a/tests/utils/megatron/test_pipeline_parallel.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards -from verl.utils.megatron.pipeline_parallel import make_batch_generator - - -def test_make_batch_generator_no_vpp(): - batches = [1, 2, 3] - vpp_size = 1 - generator = make_batch_generator(batches, vpp_size) - assert list(generator) == batches - - -def test_make_batch_generator_with_vpp(): - batches = [{"data": 1}, {"data": 2}] - vpp_size = 2 - generators = make_batch_generator(batches, vpp_size) - assert isinstance(generators, list) - assert len(generators) == vpp_size - - # Check each generator yields the original batches - for gen in generators: - assert list(gen) == batches - - -def test_make_batch_generator_empty(): - batches = [] - vpp_size = 1 - generator = make_batch_generator(batches, vpp_size) - assert list(generator) == [] - - vpp_size = 3 - generators = make_batch_generator(batches, vpp_size) - assert len(generators) == vpp_size - for gen in generators: - assert list(gen) == [] - - -@pytest.mark.parametrize( - "layer_num,pp_size,gt", - [ - (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]), - (61, 7, [8, 9, 9, 9, 9, 9, 8]), - (61, 1, [61]), - (61, 0, ValueError), - (10, 16, ValueError), - ], -) -def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt): - if isinstance(gt, list): - shards = get_dynamic_pipeline_shards(layer_num, pp_size) - assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}" - assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}" - elif issubclass(gt, Exception): - with pytest.raises(gt): - shards = get_dynamic_pipeline_shards(layer_num, pp_size) diff --git a/verl/model_merger/__main__.py b/verl/model_merger/__main__.py index f3ab5b9c29b..c801a82cb2e 100644 --- a/verl/model_merger/__main__.py +++ b/verl/model_merger/__main__.py @@ -58,10 +58,6 @@ def main(): from .fsdp_model_merger import FSDPModelMerger merger = FSDPModelMerger(config) - elif config.backend == 
"megatron": - from .megatron_model_merger import MegatronModelMerger - - merger = MegatronModelMerger(config) else: raise NotImplementedError(f"Unknown backend: {config.backend}") diff --git a/verl/model_merger/megatron_model_merger.py b/verl/model_merger/megatron_model_merger.py deleted file mode 100644 index bccd54d2ab1..00000000000 --- a/verl/model_merger/megatron_model_merger.py +++ /dev/null @@ -1,546 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import warnings -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Callable, ContextManager - -import numpy as np -import torch -import torch.distributed as dist - -try: - # NPU patch - import mindspeed.megatron_adaptor # noqa: F401 -except ImportError: - pass - -from accelerate import init_empty_weights -from megatron.core import mpu -from megatron.core.models.gpt.gpt_model import ModelType -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from safetensors.torch import load_file -from transformers import ( - AutoConfig, - PretrainedConfig, -) - -from verl.models.mcore import hf_to_mcore_config -from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device -from verl.utils.distributed import set_numa_affinity -from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing -from verl.utils.megatron_utils import get_model -from verl.utils.tokenizer import hf_processor, hf_tokenizer - -from .base_model_merger import BaseModelMerger, ModelMergerConfig - - -@contextmanager -def noop_context() -> Any: - yield - - -def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]: - """Calculate the pipeline sharding configuration for Megatron-LM. - - Args: - layer_num: Total number of layers in the model. - pp_size: Number of pipeline parallel ranks. - - Returns: - layer number of each pp rank. Make the sharding of the pipeline as uniform as possible. 
- """ - if layer_num < pp_size: - raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.") - - if pp_size < 1: - raise ValueError(f"pp_size must be at least 1, got {pp_size}.") - if pp_size == 1: - return [layer_num] - - if pp_size == 2: - return [ - layer_num // 2, - layer_num - layer_num // 2, - ] - - middle_size = pp_size - 2 - shards_strategy = [] - for middle_layer_num in range(layer_num): - first_last_layer_num = layer_num - middle_layer_num * middle_size - first_layer_num = first_last_layer_num // 2 - last_layer_num = first_last_layer_num - first_last_layer_num // 2 - if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num: - shards_strategy.append( - ( - [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num], - abs(first_layer_num - middle_layer_num), - ) - ) - - # sort by diff of layer_num, to make it as uniform as possible - res = sorted(shards_strategy, key=lambda x: x[1])[0][0] - assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}" - return res - - -class MegatronModelMerger(BaseModelMerger): - """ - Model merger for Megatron-LM distributed checkpoints. - - This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format. - Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute - large language models across multiple GPUs. This merger reconstructs the full model by - loading distributed checkpoints and applying the necessary transformations. 
- - Key features: - - Support for tensor parallel, pipeline parallel, and data parallel configurations - - Automatic parameter name mapping from Megatron to HuggingFace conventions - - Handling of QKV and gate-up tensor splitting/merging - - Support for tied word embeddings and value models - - Integration with Megatron's distributed checkpointing system - - The merger handles various model architectures and configurations: - - Standard transformer models (GPT-style) - - Models with tied word embeddings - - Value models for reinforcement learning - - Multi-layer attention (MLA) architectures - - Mixture of Experts (MoE) models - - Args: - config (ModelMergerConfig): Configuration object with Megatron-specific settings - including tie_word_embedding and is_value_model flags. - - Example: - To merge Megatron checkpoints: - ```python - config = ModelMergerConfig( - operation="merge", - backend="megatron", - local_dir="path/to/megatron/checkpoints", - target_dir="path/to/output", - tie_word_embedding=True - ) - merger = MegatronModelMerger(config) - merger.merge_and_save() - ``` - """ - - def __init__(self, config: ModelMergerConfig): - super().__init__(config) - # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards - if "WORLD_SIZE" not in os.environ: - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - set_numa_affinity() - torch.distributed.init_process_group(get_nccl_backend()) - - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - local_rank = os.environ.get("LOCAL_RANK", 0) - get_torch_device().set_device(f"{get_device_name()}:{local_rank}") - - mpu.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=self.world_size, - virtual_pipeline_model_parallel_size=None, - context_parallel_size=1, - 
expert_model_parallel_size=1, - ) - model_parallel_cuda_manual_seed(0) - self.hf_config = AutoConfig.from_pretrained( - self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code - ) - print(self.hf_config, flush=True) - - self.params_mapping = { - # megatron core gpt model name, huggingface model name - # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the - # longer key within the containing relationship is processed first. - "embedding.word_embeddings": "model.embed_tokens", - # input layer norm for dpskv3 - "input_layernorm.weight": "input_layernorm.weight", - "input_layernorm.bias": "input_layernorm.bias", - # attn - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", - "self_attention.linear_qkv": "self_attn.qkv_proj", - "self_attention.q_layernorm": "self_attn.q_norm", - "self_attention.k_layernorm": "self_attn.k_norm", - "self_attention.linear_proj": "self_attn.o_proj", - # mla - "self_attention.linear_q_proj": "self_attn.q_proj", - "self_attention.linear_q_down_proj": "self_attn.q_a_proj", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - "self_attention.linear_q_up_proj": "self_attn.q_b_proj", - "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", - # mlp - "pre_mlp_layernorm": "post_attention_layernorm", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", - "mlp.linear_fc1": "mlp.gate_up_proj", - "mlp.linear_fc2": "mlp.down_proj", - # moe - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - "mlp.router": "mlp.gate", - "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", - 
"mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", - "linear_fc1": "gate_up_proj", - "linear_fc2": "down_proj", - # output - "final_layernorm": "norm", - "output_layer": "lm_head", - } - - if "Qwen2MoeForCausalLM" in self.hf_config.architectures: - self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj" - self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" - self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight" - - def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]: - """_summary_ - Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory. - - Args: - model_ckpt_path (str): Path to the model checkpoint directory. - - Returns: - State dict containing the model parameters. - """ - - # init hf config - self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size) - print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}") - - tf_config = hf_to_mcore_config( - self.hf_config, - torch.bfloat16, - num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None, - num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None, - ) - tf_config.use_cpu_initialization = self.config.use_cpu_initialization - tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False) - - # init megatron model - def megatron_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=tie_word_embeddings, - value=False, - ) - return parallel_model - - context: Callable[..., ContextManager] = ( - init_empty_weights if self.config.use_cpu_initialization else noop_context - ) - with context(): - 
whole_model = get_model( - model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - transformer_config=tf_config, - ) - - if self.config.use_cpu_initialization: - # convert meta device to empty tensor so it can use `copy_` function - whole_model[0].module = whole_model[0].module.to_empty(device="cpu") - - # load state dicts - sharded_state_dict = {} - for vpp_rank, model in enumerate(whole_model): - key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - sharded_state_dict[key] = model.sharded_state_dict() - model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path) - model_state_dict_list = [] - for vpp_rank, model in enumerate(whole_model): - key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - model_state_dict_list.append(model_state_dict[key]) - - return model_state_dict_list - - def _check_megatron_state_key(self, key: str) -> bool: - """ - Checks if the key is a valid Megatron state key. - - Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. - Shall not use key starts with "model." - """ - if key.startswith("model."): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with " - f"'decoder/embedding/output_layer' in TransformerLayer." - ) - - skip_checking_keys = ["embedding.word_embeddings", "output_layer"] - for skip_key in skip_checking_keys: - if skip_key in key: - print(f"skip checking key {key}") - return - - # Exclude extra state keys - if not key.startswith("decoder"): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." 
- ) - - def _split_tensors( - self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False - ) -> list[torch.Tensor]: - """ - Splits a tensor into multiple tensors based on the name. - This is used to handle qkv and gate_up tensors. - """ - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - gate, up = tensor.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst, k_lst, v_lst = [], [], [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( - f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" - ) - kv_size = tensor.shape[0] // (num_q_per_kv + 2) - split_size = [kv_size * num_q_per_kv, kv_size, kv_size] - - num_query_groups_per_partition = config.num_key_value_heads - for chunk in tensor.chunk(num_query_groups_per_partition): - split_size = [ - kv_size * num_q_per_kv // num_query_groups_per_partition, - kv_size // num_query_groups_per_partition, - kv_size // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] - else: - return [tensor] - - def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: - state_dict = {} - layers_cum = 0 - if self.world_size > 1: - pipeline_cumsum = np.cumsum(self.pipeline_shards) - layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] - - print(f"{layers_cum=}") - for model_state_dict in model_state_dict_list: - layers_handled 
= 0 - keys = model_state_dict.keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - self._check_megatron_state_key(key) - hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." - if "model.layers." in hf_name: - local_layer_no = int(hf_name.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = hf_name.split(".") - new_key_list[2] = str(global_layer_no) - hf_name = ".".join(new_key_list) - else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) - - if "mlp.experts." in hf_name and ".weight" in hf_name: - name_prefix, expert_id = hf_name.split(".weight") - for proj in ["gate_up", "down"]: - if f"{proj}_proj" in hf_name: - hf_name = hf_name.replace( - f"mlp.experts.{proj}_proj.weight{expert_id}", - f"mlp.experts.{expert_id}.{proj}_proj.weight", - ) - - tensor = model_state_dict[key] - split_tensor = self._split_tensors( - key, tensor, self.hf_config, is_value_model=self.config.is_value_model - ) - - if len(split_tensor) == 1: - state_dict[hf_name] = split_tensor[0] - elif len(split_tensor) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], split_tensor, strict=True): - state_dict[hf_name.replace("qkv", n)] = d - elif len(split_tensor) == 2: - # split gate up - state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] - state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] - shape_info = ( - split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] - ) - print(f"converted {key} to {hf_name} with shape {shape_info}") - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def save_hf_model_and_tokenizer(self, 
merged_state_dict): - if self.world_size == 1: - return super().save_hf_model_and_tokenizer(merged_state_dict) - - from safetensors.torch import save_file - - layer_num = self.hf_config.num_hidden_layers - - # FIXME: make configurable - saves_per_layer = 1 if layer_num < 30 else 2 - saves_total = saves_per_layer * layer_num - saves_indexes = {} - - # calculate the layer start index and key chunks - layer_this_rank = self.pipeline_shards[self.rank] - pipeline_cumsum = np.cumsum(self.pipeline_shards) - layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] - keys = list(merged_state_dict.keys()) - keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer) - numel = 0 - - assert len(keys_chunk) == layer_this_rank * saves_per_layer, ( - f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}." - ) - - # save to model shards manually - target_dir = Path(self.config.target_dir) - for i, keys in enumerate(keys_chunk): - sd_to_save = {k: merged_state_dict[k] for k in keys} - numel += sum([sd_to_save[i].numel() for i in sd_to_save]) - save_idx = layer_start * saves_per_layer + i - save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors" - - save_file(sd_to_save, save_path) - for k in keys: - saves_indexes[k] = str(save_path.name) - - tensor = torch.tensor([numel]).to(get_device_name()) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - numel = tensor.cpu().item() - - all_save_indexes = [{} for _ in range(self.world_size)] - dist.all_gather_object(all_save_indexes, saves_indexes) - saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()} - if self.rank == 0: - with open(target_dir / "model.safetensors.index.json", "w") as f: - json.dump( - { - "metadata": { - "total_size": numel, - }, - "weight_map": saves_indexes, - }, - f, - indent=4, - ) - print(f"model saved to {target_dir} with {numel=}") - - self.model_config.save_pretrained(self.config.target_dir) 
- - processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def merge_and_save(self): - from verl.utils.megatron_utils import get_dist_checkpoint_path - - model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) - - model_state_dict = self._load_state_dicts(model_ckpt_path) - merged_state_dict = self._merge_state_dicts(model_state_dict) - del model_state_dict - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._validate_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. 
- """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - for name, loaded_weight in state_dict.items(): - # name = self._replace_name(original_name, self.params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if "lm_head.weight" in name: - if self.config.is_value_model or self.config.tie_word_embedding: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: - for m_name, v_name in name_mapping.items(): - if m_name not in megatron_name: - continue - - megatron_name = megatron_name.replace("decoder", "model") - param_name = megatron_name.replace(m_name, v_name) - - return param_name - - return None # Return None if no mapping found - - def cleanup(self): - torch.distributed.destroy_process_group() diff --git a/verl/models/llama/__init__.py b/verl/models/llama/__init__.py deleted file mode 100644 index 1ce90c5eb35..00000000000 --- a/verl/models/llama/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/verl/models/llama/megatron/__init__.py b/verl/models/llama/megatron/__init__.py deleted file mode 100644 index fc851ea435f..00000000000 --- a/verl/models/llama/megatron/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .modeling_llama_megatron import ( - ParallelLlamaForCausalLM, - # rmpad with megatron - ParallelLlamaForCausalLMRmPad, - # rmpad with megatron and pipeline parallelism - ParallelLlamaForCausalLMRmPadPP, - ParallelLlamaForValueRmPad, - ParallelLlamaForValueRmPadPP, - # original model with megatron - ParallelLlamaModel, -) - -__all__ = [ - "ParallelLlamaForCausalLM", - "ParallelLlamaForCausalLMRmPad", - "ParallelLlamaForCausalLMRmPadPP", - "ParallelLlamaForValueRmPad", - "ParallelLlamaForValueRmPadPP", - "ParallelLlamaModel", -] diff --git a/verl/models/llama/megatron/checkpoint_utils/__init__.py b/verl/models/llama/megatron/checkpoint_utils/__init__.py deleted file mode 100644 index 1ce90c5eb35..00000000000 --- a/verl/models/llama/megatron/checkpoint_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py deleted file mode 100644 index dafecfdf084..00000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def fetch_params(module): - for param in module.parameters(): - torch.distributed.fetch( - param.data, 
src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _fetch_tensor(tensor, name) -> torch.Tensor: - """fetch tensor""" - nonlocal state_dict - if tensor is not None: - tensor.data.copy_(state_dict[name]) - - def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is 
not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """fetch gate_up tensor in tp shards""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if gate_name in state_dict and up_name in state_dict: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: - """fetch 
tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - - # Embeddings - # ------------------- - 
print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - - layer_list = [] - if vpp_size is not None: - for vpp_rank in range(vpp_size): - num_layer_vpp_chunk = num_layer_per_pp // vpp_size - num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( - mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk - ) - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - else: - num_layer_this_model = num_layer_per_pp - offset = pp_rank * num_layer_per_pp - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - - for layer in layer_list: - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _fetch_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _fetch_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _fetch_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - 
f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _fetch_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _fetch_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _fetch_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _fetch_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _fetch_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - else: - _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") - - dist.barrier() - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py deleted file mode 100644 index 2f65bc6b170..00000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright 2024 Bytedance Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to 
sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast 
tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == 0: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == 0: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=0, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape 
{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - 
"""broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, 
k_name, v_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = 
tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else 
None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - - get_torch_device().empty_cache() - print_rank_0(f"loading 
megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py deleted file mode 100644 index 595efcde376..00000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): - """given TP,DP,PP rank to get the global rank.""" - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) - # We only support TP-DP-PP grouping, for correctness when resharding - return (pp_rank * dp_size + dp_rank) * 
tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2 - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. 
- """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, 
src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, 
src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - 
tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) 
* total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - 
- _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - print_rank_0("collecting lm_head...") - - if is_value_model: - if pp_rank == pp_size - 1: - print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") - _broadcast_tensor( - gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None - else None, - "reward_head.weight", - src_pp_rank=pp_size - 1, - ) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - if dtype not in [torch.float16, torch.bfloat16, torch.float32]: - print(f'Unknown/unsupported dtype to save: {dtype}"') - exit(1) - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict diff --git a/verl/models/llama/megatron/layers/__init__.py b/verl/models/llama/megatron/layers/__init__.py deleted file mode 100644 index 352bc56086d..00000000000 --- a/verl/models/llama/megatron/layers/__init__.py +++ /dev/null @@ -1,34 +0,0 
@@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .parallel_attention import ParallelLlamaAttention -from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad -from .parallel_linear import ( - LinearForLastLayer, - MergedColumnParallelLinear, - QKVParallelLinear, -) -from .parallel_mlp import ParallelLlamaMLP -from .parallel_rmsnorm import ParallelLlamaRMSNorm - -__all__ = [ - "LinearForLastLayer", - "MergedColumnParallelLinear", - "QKVParallelLinear", - "ParallelLlamaAttention", - "ParallelLlamaDecoderLayer", - "ParallelLlamaDecoderLayerRmPad", - "ParallelLlamaMLP", - "ParallelLlamaRMSNorm", -] diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py deleted file mode 100644 index 4f76b991abd..00000000000 --- a/verl/models/llama/megatron/layers/parallel_attention.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch -import torch.nn.functional as F -from einops import rearrange -from flash_attn.layers.rotary import apply_rotary_emb -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers import LlamaConfig -from transformers.utils import is_flash_attn_2_available - -from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), 
persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): - def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): - super().__init__(dim, max_position_embeddings, base, device) - - self.factor = config.rope_scaling["factor"] # `8` in the original implementation - self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation - self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation - self.old_context_len = config.rope_scaling[ - "original_max_position_embeddings" - ] # `8192` in the original implementation - - low_freq_wavelen = self.old_context_len / self.low_freq_factor - high_freq_wavelen = self.old_context_len / self.high_freq_factor - - wavelen = 2 * math.pi / self.inv_freq - # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) - # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( - self.high_freq_factor - self.low_freq_factor - ) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class ParallelLlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config = config - self.megatron_config = megatron_config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # assign values after tp - tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, ( - f"num_head 
must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - ) - assert self.num_key_value_heads % tp_size == 0, ( - f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" - f"{self.num_key_value_heads}, tp_size={tp_size}" - ) - - self.num_heads_per_tp = self.num_heads // tp_size - self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size - self.hidden_size_per_tp = self.hidden_size // tp_size - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - - # [self.q_size, self.k_size, self.v_size] - self.qkv_proj = QKVParallelLinear( - input_size=self.hidden_size, - num_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - bias=config.attention_bias, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - self.q_size = self.num_heads_per_tp * self.head_dim - self.k_size = self.num_key_value_heads_per_tp * self.head_dim - self.v_size = self.num_key_value_heads_per_tp * self.head_dim - - self.o_proj = tensor_parallel.RowParallelLinear( - input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, - bias=config.attention_bias, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - 
self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" - scaling_type = self.config.rope_scaling[rope_type_key] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "llama3": - self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding( - self.head_dim, - self.config, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) - - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - cos, sin = 
self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) - attn_output = self.o_proj(attn_output)[0] - return attn_output - - -""" -Remove padding Attention -- Using Flash-attn 2 -- Compatible with sequence parallel -""" - - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 - - -def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): - batch_size = position_ids.shape[0] - - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, 
head_dim) - k = pad_input(k, indices, batch_size, sequence_length) - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) - k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) - - return q_embed, k_embed - - -# use flash-attn rotary embeddings with rmpad -# cos/sin shoudl be: (seq_length, rotary_dim / 2) -def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - return q_embed, k_embed - - -class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None, - ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel - - if self.megatron_config.sequence_parallel: - total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() - - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split( - [self.q_size, self.k_size, self.v_size], dim=-1 - ) # (total_nnz, 1, hidden_size) - - if self.megatron_config.sequence_parallel: - sequence_parallel_pad = total_nnz - cu_seqlens[-1] - total_nnz = cu_seqlens[-1] # total_nnz before sp padding - query_states = query_states[:total_nnz] - key_states = key_states[:total_nnz] - value_states = value_states[:total_nnz] - - # Flash attention requires the input to have the shape - # batch_size x 
seq_length x head_dime x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - - cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch - ) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, - # position_ids, indices, - - # TODO: llama does not have dropout in the config?? - # It is recommended to use dropout with FA according to the docs - # when training. - dropout_rate = 0.0 # if not self.training else self.attn_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_in_batch, - max_seqlen_k=max_seqlen_in_batch, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - - attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() - - # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled - # Here we need to repad - if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) - - attn_output_unpad = self.o_proj(attn_output_unpad)[0] - return attn_output_unpad diff --git a/verl/models/llama/megatron/layers/parallel_decoder.py b/verl/models/llama/megatron/layers/parallel_decoder.py deleted file mode 100644 index f46e9457c79..00000000000 --- a/verl/models/llama/megatron/layers/parallel_decoder.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import LlamaConfig - -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad -from .parallel_mlp import ParallelLlamaMLP -from .parallel_rmsnorm import ParallelLlamaRMSNorm - - -class ParallelLlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) - - self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
- output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Note: sequence parallel is hidden inside ColumnParallelLinear - # reduce scatter is hidden inside RowParallelLinear - - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # TODO: add sequence parallel operator all_gather here - - hidden_states = self.mlp(hidden_states) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs - - -class ParallelLlamaDecoderLayerRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) - - self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: 
Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states # (total_nnz // sp, 1, hidden_size) - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) - # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - # shape changes same as attn - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs diff --git a/verl/models/llama/megatron/layers/parallel_linear.py b/verl/models/llama/megatron/layers/parallel_linear.py deleted file mode 100644 index 043726c46c3..00000000000 --- a/verl/models/llama/megatron/layers/parallel_linear.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py - -import torch -from megatron.core import tensor_parallel - - -class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.q_output_size = num_heads * head_dim - self.kv_output_size = num_key_value_heads * head_dim - self.head_dim = head_dim - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - input_size = self.input_size - output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - - super().__init__( - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.output_size = gate_ouput_size + up_output_size - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - super().__init__( - input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class LinearForLastLayer(torch.nn.Linear): - def __init__( - self, - input_size, - output_size, - *, - config, - bias=True, - ): - super().__init__(in_features=input_size, out_features=output_size, bias=bias) - self.sequence_parallel = config.sequence_parallel - if self.sequence_parallel: - self.weight.sequence_parallel = True - - def forward( - self, - input_, - weight=None, - 
runtime_gather_output=None, - ): - logits = super().forward(input_) - logits = logits.float() - if self.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits, None diff --git a/verl/models/llama/megatron/layers/parallel_mlp.py b/verl/models/llama/megatron/layers/parallel_mlp.py deleted file mode 100644 index 583a317eb6a..00000000000 --- a/verl/models/llama/megatron/layers/parallel_mlp.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers.activations import ACT2FN - -from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class ParallelLlamaMLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - - tp_size = mpu.get_tensor_model_parallel_world_size() - - self.gate_up_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - gate_ouput_size=self.intermediate_size, - up_output_size=self.intermediate_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - self.gate_size = self.intermediate_size // tp_size - - self.down_proj = tensor_parallel.RowParallelLinear( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - gate_up = self.gate_up_proj(x)[0] - gate, up = gate_up.split(self.gate_size, dim=-1) - return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/llama/megatron/layers/parallel_rmsnorm.py 
b/verl/models/llama/megatron/layers/parallel_rmsnorm.py deleted file mode 100644 index bc2e9ae36f0..00000000000 --- a/verl/models/llama/megatron/layers/parallel_rmsnorm.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numbers - -import torch -from apex.normalization.fused_layer_norm import fused_rms_norm_affine -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import LlamaConfig - -from verl.utils.megatron import sequence_parallel as sp_utils - - -class ParallelLlamaRMSNorm(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - if isinstance(config.hidden_size, numbers.Integral): - normalized_shape = (config.hidden_size,) - self.normalized_shape = torch.Size(normalized_shape) - self.weight = nn.Parameter(torch.ones(self.normalized_shape)) - self.variance_epsilon = config.rms_norm_eps - - if megatron_config.sequence_parallel: - sp_utils.mark_parameter_as_sequence_parallel(self.weight) - - def forward(self, hidden_states): - return fused_rms_norm_affine( - input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True, - ) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py 
deleted file mode 100644 index e8a7e2440e6..00000000000 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch LLaMA model with Megatron-style acceleration.""" - -from typing import Optional - -import torch -import torch.utils.checkpoint -from megatron.core import ModelParallelConfig, mpu, tensor_parallel -from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import CausalLMOutputWithPast - -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm - -""" -TODO: -1. Add weight initialization. Here we need to be careful on TP weight init. -2. 
Add sequence parallel -3. Load checkpoint from meta LLama pretrained checkpoint -""" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class ParallelLlamaModel(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | BaseModelOutputWithPast: - 
""" - - Args: - input_ids: input ids. shape (batch_size, seq_length) - attention_mask: attention_mask. shape (batch_size, seq_length) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - batch_size, seq_length = input_ids.shape - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLM(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.model = ParallelLlamaModel(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401, E402 - - -class ParallelLlamaModelRmPad(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - self.megatron_config = megatron_config - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - - def 
forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLMRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - self._init_head(config) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - 
output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - logits = self.lm_head(hidden_states)[0] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - batch_size, sequence_length = input_ids.shape - - # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids = sp_utils.pad_to_sequence_parallel(input_ids) - - input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = outputs - - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids, attention_mask, position_ids) - output.logits = torch.squeeze(output.logits, dim=-1) - return output - - -""" -Support pipeline parallelism -""" - - -class ParallelLlamaModelRmPadPP(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - This model definition supports pipeline parallelism. To support pp and vpp, - - This model only contains layer in this pp stage and vpp chunk - - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - self.megatron_config = megatron_config - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - else: - self.embed_tokens = None - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = megatron_config.pipeline_model_parallel_size - self.num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = megatron_config.virtual_pipeline_model_parallel_size - vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() - - if vpp_size is not None: - self.layers = nn.ModuleList() - self.num_layer_vpp_chunk 
= self.num_layer_per_pp // vpp_size - self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) - else: - self.num_layer_this_model = self.num_layer_per_pp - offset = pp_rank * self.num_layer_per_pp - - self.layers = nn.ModuleList() - for i in range(self.num_layer_this_model): - layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) - self.layers.add_module(f"{i}", layer) - - if post_process: - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - else: - self.norm = None - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. 
shape (batch_size, seq_length) - - Returns: - - """ - if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron - # so need to deal with it by handle here: - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - else: - # self.hidden_states should be passed by Megatron - hidden_states = self.input_tensor - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - if self.post_process: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLMRmPadPP(nn.Module): - def __init__( - self, - config: LlamaConfig, - megatron_config: ModelParallelConfig, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - ): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process - ) - assert share_embeddings_and_output_weights is False, ( - "Llama Model not supports sharing embedding and output weights" - ) - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - if post_process: - self._init_head(config) - - def set_input_tensor(self, 
input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - assert len(input_tensor) == 1 - self.model.set_input_tensor(input_tensor[0]) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - # logits shape before forward_head hidden_states.shape: [4, 32, 4096] - logits = self.lm_head(hidden_states)[0] - # logits shape after forward_head logits.shape: [8, 32, 8] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - return logits - - def forward( - self, - # original input - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Returns: - ```""" - - # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. - # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model - batch_size, sequence_length = input_ids.shape - # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) - - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - if self.post_process: - hidden_states = outputs - # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. 
If input is already rmpad, we let the caller pad_input - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - else: - return outputs - - -class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - if self.post_process: - output.logits = torch.squeeze(output.logits, dim=-1) - return output - else: - return output diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py index a0f6e76f3f8..f6dddcc0121 100644 --- a/verl/models/mcore/__init__.py +++ b/verl/models/mcore/__init__.py @@ -17,16 +17,10 @@ get_mcore_forward_fn, get_mcore_forward_fused_fn, get_mcore_forward_no_padding_fn, - get_mcore_weight_converter, - 
hf_to_mcore_config, - init_mcore_model, ) __all__ = [ - "hf_to_mcore_config", - "init_mcore_model", "get_mcore_forward_fn", - "get_mcore_weight_converter", "get_mcore_forward_fused_fn", "get_mcore_forward_no_padding_fn", ] diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py deleted file mode 100644 index 93c349dd010..00000000000 --- a/verl/models/mcore/config_converter.py +++ /dev/null @@ -1,399 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# convert huggingface config to mcore transformer config - - -import warnings -from typing import TypeVar - -import torch -import torch.nn.functional as F -from megatron.core import parallel_state as mpu -from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from transformers import PretrainedConfig - -T = TypeVar("T", bound=TransformerConfig) - - -def _get_base_transformer_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> dict: - """ - Create a base TransformerConfig with common parameters across different model architectures. - TODO: (ycl) use dataclass or converter config? 
- - Args: - hf_config: HuggingFace model configuration - dtype: Data type for the model - override_transformer_config_kwargs: Additional parameters to override defaults - - Returns: - TransformerConfig with common parameters - """ - - # Common parallel state parameters - overlap_p2p_comm = ( - mpu.get_virtual_pipeline_model_parallel_world_size() is not None - and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 - ) - batch_p2p_comm = False - - # Base configuration with common parameters - base_config = { - # Model architecture parameters - "num_layers": hf_config.num_hidden_layers, - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_query_groups": hf_config.num_key_value_heads, - "ffn_hidden_size": hf_config.intermediate_size, - "attention_dropout": hf_config.attention_dropout, - "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), - "kv_channels": getattr(hf_config, "head_dim", None), - "layernorm_epsilon": hf_config.rms_norm_eps, - "add_bias_linear": True, - # Activation and normalization - "activation_func": F.silu, - "normalization": "RMSNorm", - "gated_linear_unit": True, - # Data types - "pipeline_dtype": dtype, - "params_dtype": dtype, - "bf16": dtype is torch.bfloat16, - # Parallel configuration - "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(), - "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(), - "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(), - "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(), - "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(), - "context_parallel_size": mpu.get_context_parallel_world_size(), - "overlap_p2p_comm": overlap_p2p_comm, - "batch_p2p_comm": batch_p2p_comm, - "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1, - # Common settings - "variable_seq_lengths": True, - "masked_softmax_fusion": True, - 
"moe_token_dispatcher_type": "alltoall", - } - - # Update with any provided overrides - # override_transformer_config_kwargs as kwargs shall never be none - base_config.update(override_transformer_config_kwargs) - - return base_config - - -def _get_mla_transformer_config( - hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs -) -> dict: - """ - Create a MLATransformerConfig with common parameters across different model architectures. - This is specifically for MLA models like DeepseekV3. - - Args: - hf_config: HuggingFace model configuration - mla_rope_config: MLA specific RoPE configuration - dtype: Data type for the model - override_transformer_config_kwargs: Additional parameters to override defaults - - Returns: - MLATransformerConfig with common parameters - """ - base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) - mla_config = { - # MLA specific parameters - "q_lora_rank": hf_config.q_lora_rank, - "kv_lora_rank": hf_config.kv_lora_rank, - "qk_head_dim": hf_config.qk_nope_head_dim, - "qk_pos_emb_head_dim": hf_config.qk_rope_head_dim, - "v_head_dim": hf_config.v_head_dim, - "rotary_base": hf_config.rope_theta, - "rotary_scaling_factor": mla_rope_config["factor"], - "rope_type": mla_rope_config["type"], - "max_position_embeddings": mla_rope_config["original_max_position_embeddings"], - "beta_fast": mla_rope_config["beta_fast"], - "beta_slow": mla_rope_config["beta_slow"], - "mscale": mla_rope_config["mscale"], - "mscale_all_dim": mla_rope_config["mscale_all_dim"], - } - - base_config.update(mla_config) - return base_config - - -def check_and_construct_configs(original_config: dict, cls: type[T]) -> T: - """ - Check and disable incompatible configurations for older Megatron version. - - Args: - original_config (dict): The original model configuration. - - Returns: - dict: The updated model configuration with incompatible settings disabled. 
- """ - removed_keys = [] - for key in original_config.keys(): - if not hasattr(cls, key): - removed_keys.append(key) - if removed_keys: - warnings.warn( - f"The following keys are not supported in the current Megatron version and will be removed: {removed_keys}", - stacklevel=2, - ) - for key in removed_keys: - original_config.pop(key) - - original_config = mapping_string_to_attn_backend(original_config) - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - print(f"Overridden {cls.__name__} init config: {original_config}") - return cls(**original_config) - - -def hf_to_mcore_config_dense( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # for LlamaForCausalLM or Qwen2ForCausalLM - qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False) - qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False - - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - add_qkv_bias=qkv_bias, - qk_layernorm=qk_layernorm, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - return check_and_construct_configs(args, TransformerConfig) - - -def hf_to_mcore_config_qwen2moe( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_router_bias_update_rate=0.001, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.num_experts, - moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, - 
moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - # moe_aux_loss_coeff=0.0, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_shared_expert_overlap=True, - moe_grouped_gemm=True, - moe_router_score_function="softmax", - # Other optimizations - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - # Qwen specific - moe_router_pre_softmax=True, - add_qkv_bias=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - return check_and_construct_configs(args, TransformerConfig) - - -def hf_to_mcore_config_mixtral( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - num_moe_experts=hf_config.num_local_experts, - moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - moe_router_topk=hf_config.num_experts_per_tok, - moe_router_pre_softmax=True, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_router_score_function="softmax", - moe_shared_expert_intermediate_size=None, # mixtral has no shared expert - moe_shared_expert_overlap=False, # mixtral has no shared expert - moe_ffn_hidden_size=hf_config.intermediate_size, - moe_router_bias_update_rate=0.001, - # moe_permute_fusion=True, # need TE 2.1+ - moe_grouped_gemm=True, - # Other optimizations - persist_layer_norm=True, - apply_rope_fusion=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - return check_and_construct_configs(args, TransformerConfig) - - -def hf_to_mcore_config_qwen3moe( - hf_config: PretrainedConfig, dtype: torch.dtype, 
**override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_router_bias_update_rate=0.001, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.num_experts, - moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - # moe_aux_loss_coeff=0.0, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_grouped_gemm=True, - moe_router_score_function="softmax", - # Other optimizations - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - # Qwen specific - moe_router_pre_softmax=False, - qk_layernorm=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - return check_and_construct_configs(args, TransformerConfig) - - -def hf_to_mcore_config_dpskv3( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> MLATransformerConfig: - # DeepseekV3ForCausalLM - from megatron.core.config import set_experimental_flag - from megatron.core.transformer.enums import AttnBackend - - set_experimental_flag(True) - - from .patch_v012 import apply_patch - - apply_patch() - - mla_rope_config = { - "beta_fast": 32, - "beta_slow": 1, - "factor": 1, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": 4096, - "type": "rope", - } - if "rope_scaling" in hf_config and hf_config.rope_scaling is not None: - mla_rope_config.update(hf_config.rope_scaling) - moe_layer_freq = [1] * hf_config.num_hidden_layers - for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)): - moe_layer_freq[i] = 0 - - # disable MTP and quantization for now - if "num_nextn_predict_layers" in hf_config: - assert 
hf_config.num_nextn_predict_layers == 0, ( - "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" - ) - assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( - "quantization is not supported for now, please modify the config.json to remove quantization_config" - ) - - args: dict = _get_mla_transformer_config( - hf_config=hf_config, - mla_rope_config=mla_rope_config, - dtype=dtype, - # Additional parameters - use_cpu_initialization=False, - add_bias_linear=False, - attention_backend=AttnBackend.fused, - qk_layernorm=True, - # Standard MoE parameters - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_token_dispatcher_type="alltoall", - moe_router_bias_update_rate=0.001, - moe_router_enable_expert_bias=True, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.n_routed_experts, - moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts, - moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001), - moe_router_load_balancing_type="seq_aux_loss", - moe_shared_expert_overlap=True, - # moe_permute_fusion=True, # need TE 2.1+ - moe_grouped_gemm=True, - moe_router_score_function="sigmoid", - moe_router_pre_softmax=True, - moe_router_topk_scaling_factor=hf_config.routed_scaling_factor, - moe_layer_freq=moe_layer_freq, - # mcore 0.12 moe - moe_router_dtype="fp64", - disable_bf16_reduced_precision_matmul=True, - # Other optimizations - # deallocate_pipeline_outputs=True, - # gradient_accumulation_fusion=True, - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - transformer_config = check_and_construct_configs(args, MLATransformerConfig) - # MTP - if "num_nextn_predict_layers" in hf_config: - transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers - 
transformer_config.mtp_loss_scaling_factor = 0.1 - - return transformer_config - - -def hf_to_mcore_config_qwen2_5_vl( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # Qwen2_5_VLForConditionalGeneration - - args = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - add_bias_linear=False, - # qwen specific - add_qkv_bias=True, - mrope_section=hf_config.rope_scaling["mrope_section"], - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - args = mapping_string_to_attn_backend(args) - return TransformerConfig(**args) - - -def hf_to_mcore_config_llama4( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # Llama4ForConditionalGeneration - raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") - - -def mapping_string_to_attn_backend(args: dict) -> dict: - if "attention_backend" in args and isinstance(args["attention_backend"], str): - from megatron.core.transformer.enums import AttnBackend - - args["attention_backend"] = AttnBackend[args["attention_backend"]] - return args diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py deleted file mode 100644 index 577ffc5ecf4..00000000000 --- a/verl/models/mcore/loader.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - -from .saver import _megatron_calc_global_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( 
- param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == src_rank: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.decoder.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if 
tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == src_rank: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=src_rank, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal 
state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = 
config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = 
state_dict[v_name] - - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - num_query_groups_per_partition = models[0].config.num_query_groups // tp_size - new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) - k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) - v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) - total_size_per_head = total_size // num_query_groups_per_partition - for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) - ) - - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * 
hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) - k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) - v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) - total_size_per_head = total_size // config.num_attention_heads - for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = 
gpt_model_module.embedding.word_embeddings.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - layer_name = f"model.layers.{layer}" - print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - if f"{layer_name}.self_attn.q_norm.weight" in state_dict: - _broadcast_tensor( - sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_norm.weight", - ) - _broadcast_tensor( - sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.k_norm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - if f"{layer_name}.self_attn.q_proj.bias" in state_dict: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, - 
f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.output_layer.weight - - if is_value_model: - # if torch.distributed.get_rank() == src_rank: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - elif "score.weight" in state_dict and state_dict["score.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "score.weight") - print_rank_0("load lm_head from score weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - # else: - - # _broadcast_tensor(lm_head_weight, "lm_head.weight") - - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - pass - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/mcore/model_forward.py 
b/verl/models/mcore/model_forward.py index 09e169d7ea1..c480893911f 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -15,7 +15,7 @@ # limitations under the License. -from verl.utils.megatron_utils import unwrap_model +from megatron.core.utils import unwrap_model from .util import ( postprocess_bshd, diff --git a/verl/models/mcore/model_forward_1f1b_overlap.py b/verl/models/mcore/model_forward_1f1b_overlap.py index b8786e01f88..ed509885932 100644 --- a/verl/models/mcore/model_forward_1f1b_overlap.py +++ b/verl/models/mcore/model_forward_1f1b_overlap.py @@ -19,12 +19,11 @@ import torch from megatron.core.models.common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.utils import make_viewless_tensor +from megatron.core.utils import make_viewless_tensor, unwrap_model from torch import Tensor from verl.models.mcore.util import preprocess_packed_seqs from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy -from verl.utils.megatron_utils import unwrap_model from verl.utils.model import CausalLMOutputForPPO from .util import postprocess_packed_seqs, postprocess_packed_seqs_for_dict_output diff --git a/verl/models/mcore/model_forward_fused.py b/verl/models/mcore/model_forward_fused.py index bf5dfdb37fd..6ecacf82af5 100644 --- a/verl/models/mcore/model_forward_fused.py +++ b/verl/models/mcore/model_forward_fused.py @@ -25,12 +25,11 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from megatron.core.utils import deprecate_inference_params +from megatron.core.utils import deprecate_inference_params, unwrap_model from torch import Tensor from verl.models.mcore.util import preprocess_packed_seqs from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy -from 
verl.utils.megatron_utils import unwrap_model from verl.utils.model import CausalLMOutputForPPO from .util import postprocess_packed_seqs_for_dict_output diff --git a/verl/models/mcore/model_initializer.py b/verl/models/mcore/model_initializer.py deleted file mode 100644 index 49a30bc9e2c..00000000000 --- a/verl/models/mcore/model_initializer.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# use mcore transformer config to initialize the model -import inspect -from abc import ABC, abstractmethod - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec -from megatron.core.models.gpt.gpt_model import GPTModel - -from .config_converter import PretrainedConfig, TransformerConfig - - -class BaseModelInitializer(ABC): - """Base class for model initializers.""" - - def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): - self.tfconfig = tfconfig - self.hf_config = hf_config - self.has_vp_stage = inspect.signature(get_gpt_decoder_block_spec).parameters.get("vp_stage", None) is not None - - @abstractmethod - def get_transformer_layer_spec(self, vp_stage=None): - """Get the transformer layer specification. 
- https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" - pass - - def get_rope_scaling_args(self) -> dict: - """Get rope scaling args.""" - rope_scaling_args = {} - if "rope_scaling" in self.hf_config: - if self.hf_config.rope_scaling is not None: - # assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] - return rope_scaling_args - - def initialize( - self, - pre_process: bool = True, - post_process: bool = True, - share_embeddings_and_output_weights: bool = False, - value: bool = False, - **extra_kwargs, - ) -> GPTModel: - """Initialize a GPT model with the given configuration. - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py - - Args: - pre_process (bool): include embedding layer. - post_process (bool): including an output layer. - share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared. - value (bool): add an extra linear layer for classification or regression. 
- - Returns: - GPTModel: An initialized GPT model instance - """ - vp_stage = extra_kwargs.get("vp_stage", None) - transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) - rope_scaling_args = self.get_rope_scaling_args() - mtp_block_spec = extra_kwargs.get("mtp_block_spec", None) - model = GPTModel( - config=self.tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=self.hf_config.vocab_size, - max_sequence_length=self.hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type="rope", - rotary_base=self.hf_config.rope_theta, - **rope_scaling_args, - mtp_block_spec=mtp_block_spec, - **({} if not self.has_vp_stage else {"vp_stage": vp_stage}), - ) - - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - model.output_layer = LinearForLastLayer( - input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig - ) - - return model - - -class DenseModel(BaseModelInitializer): - """Initializer for dense models like Llama and Qwen2.""" - - def get_transformer_layer_spec(self, vp_stage=None): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - - -class Qwen2MoEModel(BaseModelInitializer): - """Initializer for Qwen2 MoE models.""" - - def get_transformer_layer_spec(self, vp_stage=None): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - - # Patch layer spec for shared experts - for i in 
range(len(transformer_layer_spec.layer_specs)): - transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True - - return transformer_layer_spec - - def initialize(self, **kwargs): - # Qwen default freeze_moe_router: true - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class MixtralModel(BaseModelInitializer): - """Initializer for Mixtral models.""" - - def get_transformer_layer_spec(self, vp_stage=None): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - return transformer_layer_spec - - def initialize(self, **kwargs): - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", False) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class Qwen3MoEModel(BaseModelInitializer): - """Initializer for Qwen3 MoE models.""" - - def get_transformer_layer_spec(self, vp_stage=None): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - return transformer_layer_spec - - def initialize(self, **kwargs): - # Qwen default freeze_moe_router: true - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class DeepseekV3Model(BaseModelInitializer): - 
"""Initializer for DeepseekV3 models.""" - - def get_transformer_layer_spec(self, vp_stage=None): - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - return transformer_layer_spec - - def get_rope_scaling_args(self) -> dict: - """Get rope scaling args.""" - rope_scaling_args = {} - return rope_scaling_args - - def initialize( - self, - **kwargs, - ): - vp_stage = kwargs.get("vp_stage", None) - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - self.tfconfig.moe_router_load_balancing_type = "none" - # MTP - if self.tfconfig.mtp_num_layers is not None and self.tfconfig.mtp_num_layers > 0: - transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) - mtp_block_spec = get_gpt_mtp_block_spec( - self.tfconfig, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage - ) - kwargs["mtp_block_spec"] = mtp_block_spec - - model = super().initialize(**kwargs) - if freeze_moe_router: - for layer in model.decoder.layers: - if hasattr(layer.mlp, "router"): - layer.mlp.router.weight.requires_grad = False - return model - - -class Qwen25VLModel(BaseModelInitializer): - """Initializer for Qwen2.5 VL models.""" - - def get_transformer_layer_spec(self, vp_stage=None): - extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage} - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs) - return transformer_layer_spec - - def initialize( - self, - pre_process=None, - post_process=None, - share_embeddings_and_output_weights=False, - value=False, - **extra_kwargs, - ): - tfconfig = self.tfconfig - hf_config = self.hf_config - # Qwen2_5_VLForConditionalGeneration - from copy import deepcopy - - transformer_layer_spec = self.get_transformer_layer_spec() - - from megatron.core.extensions.transformer_engine import 
TEColumnParallelLinear, TERowParallelLinear - from megatron.core.models.gpt.moe_module_specs import MLPSubmodules - from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec - - from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config - - vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) - vision_transformer_config.pipeline_model_parallel_size = 1 - vision_transformer_config.first_pipeline_num_layers = None - - vision_projection_config = get_vision_projection_config( - deepcopy(tfconfig), - vision_transformer_config.hidden_size, - spatial_merge_size=hf_config.vision_config.spatial_merge_size, - ) - vision_projection_layer_spec = MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ) - vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() - - qwen25_vl_model = Qwen2_5VLModel( - language_transformer_config=tfconfig, - language_transformer_layer_spec=transformer_layer_spec, - language_vocab_size=hf_config.vocab_size, - language_max_sequence_length=hf_config.max_position_embeddings, - vision_transformer_config=vision_transformer_config, - vision_transformer_layer_spec=vision_transformer_layer_spec, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_layer_spec, - vision_projection_type="mlp", - language_rotary_base=hf_config.rope_theta, - pre_process=pre_process, - post_process=post_process, - add_decoder=True, - add_encoder=True, - parallel_output=True, - language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, - ) - - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - qwen25_vl_model.language_model.output_layer = LinearForLastLayer( - input_size=tfconfig.hidden_size, output_size=1, config=tfconfig - ) - - return qwen25_vl_model diff --git a/verl/models/mcore/patch_v012.py 
b/verl/models/mcore/patch_v012.py deleted file mode 100644 index d54a3eb346d..00000000000 --- a/verl/models/mcore/patch_v012.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# there is some bug in mcore 0.12, so we need to patch it -# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None - - -def apply_patch(): - import torch - from megatron.core import parallel_state, tensor_parallel - from megatron.core.transformer.multi_latent_attention import ( - MLASelfAttention, - apply_rotary_pos_emb, - deprecate_inference_params, - gather_from_sequence_parallel_region, - gather_from_tensor_model_parallel_region, - scatter_to_sequence_parallel_region, - ) - - def patch_get_query_key_value_tensors( - self, - hidden_states, - key_value_states=None, - position_ids=None, - packed_seq_params=None, - inference_context=None, - *, - inference_params=None, - ): - """ - Derives `query`, `key` and `value` tensors from `hidden_states`. 
- """ - # s = sequence length, b = batch size, h = hidden size, n = num attention heads - # Attention heads [s, b, n*h] - assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # ========================================= - # Prepare RoPE and seqlen related params - # ========================================= - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, None, hidden_states, self.config, packed_seq_params - ) - - # rotary_pos_emb:[s, b, 1, 64] - mscale = 1.0 - if self.config.rope_type == "rope": - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) - else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len) - - # ========================================= - # QKV down projection and layernorm - # ========================================= - if self.config.q_lora_rank is not None: - # if linear_q_down_proj is ColumnParallelLinear: - # q_compressed: [s, b, q_lora_rank / TP] - # elif linear_q_down_proj is Linear: - # q_compressed: [s / TP, b, q_lora_rank] - q_compressed, _ = self.linear_q_down_proj(hidden_states) - - # When output is sharded (ColumnParallelLinear), two things are needed to be - # identical to a normal Linear. - # 1. Manually gather output to restore output dim q_lora_rank; - # 2. Scatter sequence back to s / TP if sequence-parallel since it was - # gathered by ColumnParallelLinear. 
- if q_compressed.size(-1) != self.config.q_lora_rank: - q_compressed = gather_from_tensor_model_parallel_region(q_compressed) - if self.config.sequence_parallel: - q_compressed = scatter_to_sequence_parallel_region(q_compressed) - - q_compressed = self.q_layernorm(q_compressed) - else: - q_compressed = hidden_states - - # if linear_kv_down_proj is ColumnParallelLinear: - # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] - # elif linear_kv_down_proj is Linear: - # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] - kv_combined, _ = self.linear_kv_down_proj(hidden_states) - if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: - # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] - kv_combined = gather_from_tensor_model_parallel_region(kv_combined) - # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 - ) - if self.config.sequence_parallel: - # kv_compressed:[s / TP, b, kv_lora_rank] - kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) - else: - # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 - ) - if parallel_state.get_tensor_model_parallel_world_size() > 1: - # k_pos_emb: [s, b, qk_pos_emb_head_dim] - k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) - - kv_compressed = self.kv_layernorm(kv_compressed) - - # ========================================= - # QKV up projection and RoPE apply - # ========================================= - def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): - if self.config.q_lora_rank is not None: - q, _ = self.linear_q_up_proj(q_compressed) - else: - # hidden_states:[s, b, 2048], q: [s, b, n * 192] 
- q, _ = self.linear_q_proj(q_compressed) - - q_len, bsz, _ = q.size() - - # q: [s, b, n, 192] - q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim) - - # kv: [s, b, 2048] - kv, _ = self.linear_kv_up_proj(kv_compressed) - - # kv: [s, b, n, 256] - kv = kv.view( - q_len, - bsz, - self.num_attention_heads_per_partition, - self.config.qk_head_dim + self.config.v_head_dim, - ) - - if inference_context is not None: - # add offset to the sequence start for inference - sequence_start = inference_context.sequence_len_offset - sequence_end = sequence_start + q_len - rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] - else: - # Shorten rotary_pos_emb to the sequence length when inference_params - # is not provided. This makes sure we can run forward directly with - # any sequence length. During training, the sequence length is always - # the full rotary_pos_emb length. - rotary_pos_emb = rotary_pos_emb[0:q_len] - - # [s, b, 64] -> [s, b, 1, 64] - k_pos_emb = torch.unsqueeze(k_pos_emb, 2) - - # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64] - q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) - - # k_no_pe: [s, b, n, 128], value: [s, b, n, 128] - k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) - - if packed_seq_params is not None: - cu_seqlens_q = packed_seq_params.cu_seqlens_q - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv - q_pos_emb = q_pos_emb.squeeze(1) - k_pos_emb = k_pos_emb.squeeze(1) - q_no_pe = q_no_pe.squeeze(1) - k_no_pe = k_no_pe.squeeze(1) - value = value.squeeze(1) - else: - cu_seqlens_q = cu_seqlens_kv = None - - # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64] - q_pos_emb = apply_rotary_pos_emb( - q_pos_emb, - rotary_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_q, - mscale=mscale, - ) - k_pos_emb = apply_rotary_pos_emb( - k_pos_emb, - rotary_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_kv, - 
mscale=mscale, - ) - - # query: [s, b, n, 192] - query = torch.cat([q_no_pe, q_pos_emb], dim=-1) - if packed_seq_params is not None: - k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) - key = torch.cat([k_no_pe, k_pos_emb], dim=-1) - else: - # key: [s, b, n, 192] - k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) - key = torch.cat([k_no_pe, k_pos_emb], dim=-1) - - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - return query, key, value - - if self.recompute_up_proj: - self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() - query, key, value = self.qkv_up_checkpoint.checkpoint( - qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb - ) - else: - query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) - - return query, key, value - - MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors diff --git a/verl/models/mcore/qwen2_5_vl/__init__.py b/verl/models/mcore/qwen2_5_vl/__init__.py deleted file mode 100644 index 8842d0249e1..00000000000 --- a/verl/models/mcore/qwen2_5_vl/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from .model import Qwen2_5VLModel -from .vision_config import get_vision_model_config, get_vision_projection_config - -__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/verl/models/mcore/qwen2_5_vl/attention.py b/verl/models/mcore/qwen2_5_vl/attention.py deleted file mode 100644 index 2a87a053c59..00000000000 --- a/verl/models/mcore/qwen2_5_vl/attention.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from megatron.core.transformer.attention import * - -from .rope_utils import apply_rotary_pos_emb_absolute - - -class Qwen2_5VLSelfAttention(SelfAttention): - """ - Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute - instead of apply_rotary_pos_emb - """ - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - key_value_states: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, - rotary_pos_cos: Optional[Tensor] = None, - rotary_pos_sin: Optional[Tensor] = None, - attention_bias: Optional[Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[int] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - rotary_pos_cos_sin: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Perform a forward pass through the attention module. - - Args: - hidden_states (Tensor): Hidden states. - attention_mask (Tensor): Attention mask. - key_value_states (Optional[Tensor]): Key/value states (for cross attention). - inference_context (Optional[BaseInferenceContext]): Inference context that manages - KV cache. - rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary - embedding tensor(s). - rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. - rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. - attention_bias (Optional[Tensor]): Attention bias. - packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. - sequence_len_offset (Optional[int]): Sequence length offset used for - inference CUDA graphs. - - Return: - (Tuple[Tensor, Tensor]) Attention output and bias. 
- - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - if inference_context and inference_context.is_dynamic_batching(): - assert flash_decode_and_prefill_kernel is not None, ( - "Internal use only: install package `nvidia_chunked_flash_attn`." - ) - - # hidden_states: [sq, b, h] - if self.config.flash_decode and not self.training and inference_context is not None: - rotary_pos_emb = None - else: - assert rotary_pos_cos is None and rotary_pos_sin is None - - # For self attention we just duplicate the rotary_pos_emb if it isn't already - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb,) * 2 - - # ===================== - # Query, Key, and Value - # ===================== - # Get the query, key and value tensors based on the type of attention - - # self or cross attn. - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - # =================================================== - # Adjust key, value, and rotary_pos_emb for inference - # =================================================== - - # This branch only runs in the decode phase of flash decoding and returns after the linear - # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. 
- if ( - self.config.flash_decode - and inference_context is not None - and inference_context.is_decode_only() - and not self.training - and rotary_pos_cos is not None - ): - assert self.layer_number in inference_context.key_value_memory_dict - assert inference_context.sequence_len_offset is not None - inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] - output = self.flash_decode( - sequence_len_offset=sequence_len_offset, - query_layer=query, - key_layer=key, - value_layer=value, - inference_key_memory=inference_key_memory, - inference_value_memory=inference_value_memory, - rotary_cos=rotary_pos_cos, - rotary_sin=rotary_pos_sin, - ) - out = output.transpose(0, 1).contiguous() - context_layer = out.view(out.size(0), out.size(1), -1) - output, bias = self.linear_proj(context_layer) - return output, bias - - # Use latest mcore 0.13 API and forward-compatible with previous versions. - outputs = self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - ) - - query, key, value, rotary_pos_emb, attn_mask_type = outputs[:5] - - if packed_seq_params is not None: - query = query.squeeze(1) - key = key.squeeze(1) - value = value.squeeze(1) - - # ================================================ - # relative positional embedding (rotary embedding) - # ================================================ - if rotary_pos_emb is not None and not self.config.flash_decode: - q_pos_emb, k_pos_emb = rotary_pos_emb - - if packed_seq_params is not None: - if packed_seq_params.cu_seqlens_q_padded is not None: - cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded - else: - cu_seqlens_q = packed_seq_params.cu_seqlens_q - if packed_seq_params.cu_seqlens_kv_padded is not None: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded - else: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv - else: - cu_seqlens_q = cu_seqlens_kv = None 
- - if q_pos_emb is not None: - # TODO VIJAY: simplify - if inference_context is None or inference_context.is_static_batching(): - query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) - else: - query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) - if k_pos_emb is not None: - key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) - - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. - # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention and self.training: - core_attn_out = self._checkpointed_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - else: - if inference_context is None or inference_context.is_static_batching(): - # Static batching attention kernel. - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - - else: - # Dynamic batching attention kernel. 
- q, k, v = (query, key, value) - cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() - cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() - - core_attn_out = self.flash_decode_and_prefill( - q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths - ) - core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) - core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") - - if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": - # reshape to same output shape as unpacked case - # (t, np, hn) -> (t, b=1, h=np*hn) - # t is the pack size = sum (sq_i) - # note that batch is a dummy dimension in the packed case - core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.linear_proj(core_attn_out) - - return output, bias diff --git a/verl/models/mcore/qwen2_5_vl/model.py b/verl/models/mcore/qwen2_5_vl/model.py deleted file mode 100644 index 91118edfb6c..00000000000 --- a/verl/models/mcore/qwen2_5_vl/model.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import logging - -import torch -from megatron.core import InferenceParams, mpu, tensor_parallel -from megatron.core.models.gpt.gpt_model import GPTModel - -# from .transformer_config import Qwen2VLTransformerConfig -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig - -from verl.models.mcore.util import preprocess_packed_seqs - -from .attention import Qwen2_5VLSelfAttention -from .vision_model import Qwen2_5VisionModel - - -# Note: This is under development and may be missing features. -class Qwen2_5VLModel(MegatronModule): - """Qwen2.5VL multi-modal model. - - Args: - language_transformer_config (TransformerConfig): Transformer config for the language model. - language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the - language model. - language_vocab_size (int): Language model vocabulary size. - language_max_sequence_length (int): Language model maximum sequence length. This is used for - positional embedding. - vision_transformer_config (TransformerConfig): Transformer config for the vision model. - vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the - vision model. - vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to - language model inputs. - vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision - projection. - vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This - is typically True for training and False for inference. - language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings - in the language model. 
Defaults to 1.0. - pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). - Defaults to True. - post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline - parallelism). Defaults to True. - add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. - When we use pipelining, the encoder - will live on only a subset of the pipeline stages (specifically, only the first stage). - add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. - When we use pipelining, the decoder - will live on only a subset of the pipeline stages (specifically, every stage after the first one). - img_h (int): The height of each image that the ViT will see. - img_w (int): The width of each image that the ViT will see. - patch_dim (int): The size of each patch side. - img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be - inserted. Defaults to 0. 
- """ - - def __init__( - self, - language_transformer_config: TransformerConfig, - language_transformer_layer_spec: ModuleSpec, - language_vocab_size: int, - language_max_sequence_length: int, - vision_transformer_config: TransformerConfig, - vision_transformer_layer_spec: ModuleSpec, - vision_projection_config: TransformerConfig, - vision_projection_layer_spec: ModuleSpec, - vision_projection_type: str = "mlp", - parallel_output: bool = True, - language_rotary_percent: float = 1.0, - pre_process: bool = True, - post_process: bool = True, - add_encoder: bool = True, - add_decoder: bool = True, - language_rotary_base: int = 10000, - fp16_lm_cross_entropy: bool = False, - language_share_embeddings_and_output_weights: bool = False, - image_token_id: int = 151655, - video_token_id: int = 151656, - ) -> None: - super().__init__(config=language_transformer_config) - - # patch self_attention to use qwen2_5_vl attention - vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention - for layer_spec in language_transformer_layer_spec.layer_specs: - layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention - - logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") - - self.pre_process = pre_process - self.post_process = post_process - self.add_encoder = add_encoder - self.add_decoder = add_decoder - - self.encoder_hidden_state = None - self.vision_model = None - self.vision_projection = None - self.language_model = None - self.image_token_id = image_token_id - self.video_token_id = video_token_id - - self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size - - # This attribute is needed to check if an all-reduce is required - # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. 
- self.share_embeddings_and_output_weights = False - if self.pre_process: - self.vision_model = Qwen2_5VisionModel( - vision_transformer_config, - vision_transformer_layer_spec, - vision_projection_config, - vision_projection_layer_spec, - projection_type=vision_projection_type, - pre_process=True, - post_process=True, - ) - - self.language_model = GPTModel( - config=language_transformer_config, - transformer_layer_spec=language_transformer_layer_spec, - vocab_size=language_vocab_size, - max_sequence_length=language_max_sequence_length, - parallel_output=parallel_output, - position_embedding_type="mrope", - rotary_percent=language_rotary_percent, - pre_process=self.pre_process, - post_process=self.post_process, - rotary_base=language_rotary_base, - fp16_lm_cross_entropy=fp16_lm_cross_entropy, - share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, - scatter_embedding_sequence_parallel=False, - ) - assert mpu.get_context_parallel_world_size() <= 1, "please use mbridge for qwen2_5_vl with context parallelism" - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - - def shared_embedding_or_output_weight(self): - """This is a convenience method to surface the language model's word embeddings, which is - necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" - if self.add_decoder: - return self.language_model.shared_embedding_or_output_weight() - return None - - def set_input_tensor(self, input_tensor) -> None: - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" - - if self.pre_process: - self.encoder_hidden_state = input_tensor[0] - else: - self.language_model.set_input_tensor(input_tensor[0]) - - def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, 
freeze_vision_projection: bool): - """Freeze model modules. - - Make specific modules non-trainable by setting requires_grad to False for the module's parameters. - - Args: - freeze_language_model (bool): Freeze the language model module. - freeze_vision_model (bool): Freeze the vision model module. - freeze_vision_projection (bool): Freeze the vision projection module. - """ - modules = [] - if freeze_language_model and self.language_model is not None: - modules.append(self.language_model) - if freeze_vision_model and self.vision_model is not None: - modules.append(self.vision_model) - if freeze_vision_projection and self.vision_projection is not None: - modules.append(self.vision_projection) - - for module in modules: - for param in module.parameters(): - param.requires_grad = False - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor = None, - labels: torch.Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - pixel_values: torch.Tensor = None, - pixel_values_videos: torch.Tensor = None, - image_grid_thw: torch.Tensor = None, - video_grid_thw: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """Forward function of the Qwen2VL model. - ### there is a workaround for supporting sequence packing with context parallelism - # cp split with sequence packing will make model lose vision token information, so we need to keep - # the original input_ids and pack them after vision embedding is calculated, - # cooporate with verl's models/mcore/model_forward.py - # pack the combined_embeddings to thd here, we check if packed_seq_params is None to determine if - # we need to pack the combined_embeddings to thd - # this function needs the position_ids and attention_mask in BSHD format, no matter use packed_seq or not - - Args: - image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. 
- input_ids (torch.Tensor): input text ids [batch, text_seq_len]. - position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. - attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, - combined_seq_len]. - labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. - inference_params (InferenceParams): Inference-time parameters including KV cache. - - video_start_index: - 0 -- all video - len(video_seq) -- all image - others -- mixture - *_input_mask: should not be None in the first PP stage - Returns: - output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape - [b, s, vocab_size]. - """ - video_start_index = 0 - vision_grid_thw = None - vision_data = None - if image_grid_thw is not None: - image_mask = input_ids == self.image_token_id - vision_grid_thw = image_grid_thw - vision_data = pixel_values - video_start_index = image_mask.sum().item() - if video_grid_thw is not None: - video_mask = input_ids == self.video_token_id - if vision_grid_thw is not None: - vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) - vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) - else: - vision_grid_thw = video_grid_thw - vision_data = pixel_values_videos - use_inference_kv_cache = ( - inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict - ) - if use_inference_kv_cache: - raise NotImplementedError() - - if self.pre_process: - vision_embeds = None - if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: - vision_embeds = self.vision_model( - vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) - grid_thw=vision_grid_thw, # should provided in each EPP stage - ) - - # If running inference, the language model KV cache will be updated for image token positions. 
- # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. - if inference_params is not None: - raise NotImplementedError() - # inference_params.key_value_memory_dict["image_tokens_count"] = ( - # vision_embeddings.shape[0] - # ) - - # If running inference, we can skip image token computation if they were computed already earlier - # for this sample. - if use_inference_kv_cache: - language_embeddings: torch.Tensor = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - # NOTE: why not cat here? is it the combined embeddings useless? - combined_embeddings = language_embeddings - elif vision_embeds is not None: - if video_start_index == 0: - image_embeds = None - video_embeds = vision_embeds - elif video_start_index == vision_embeds.shape[0]: - image_embeds = vision_embeds - video_embeds = None - elif 0 < video_start_index < vision_embeds.shape[0]: - image_embeds = vision_embeds[:video_start_index] - video_embeds = vision_embeds[video_start_index:] - else: - raise ValueError( - f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " - f"{video_start_index}" - ) - - combined_embeddings = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - - if image_embeds is not None or video_embeds is not None: - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() - if image_embeds is not None: - image_mask = (input_ids == self.image_token_id).contiguous() - if image_mask.sum() > 0: - combined_embeddings = combined_embeddings.clone() - combined_embeddings[image_mask] = image_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device - ) - if video_embeds is not None: - video_mask = (input_ids == self.video_token_id).contiguous() - if video_mask.sum() > 0: - combined_embeddings = combined_embeddings.clone() - 
combined_embeddings[video_mask] = video_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device - ) - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() - - else: - combined_embeddings = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - - if packed_seq_params is not None: - combined_embeddings = ( - preprocess_packed_seqs( - combined_embeddings.transpose(0, 1).contiguous(), attention_mask, pre_process=True - )[0] - .transpose(0, 1) - .contiguous() - ) - if self.config.sequence_parallel: - combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) - combined_embeddings = combined_embeddings.contiguous() - else: - combined_embeddings = None - from .rope_utils import get_rope_index - - # BSHD - position_ids, _ = get_rope_index( - input_ids, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - attention_mask=attention_mask, - ) - # THD - if packed_seq_params is not None: - position_ids = ( - preprocess_packed_seqs(position_ids.permute(1, 2, 0), attention_mask, pre_process=True)[0] - .permute(2, 0, 1) - .contiguous() - ) - attention_mask = None - - output = self.language_model( - input_ids=None, - position_ids=position_ids, # None in encoder - attention_mask=attention_mask, # None in encoder - decoder_input=combined_embeddings, # only not None in the first decoder PP stage - labels=labels, # only not None in the last decoder PP stage - # inference_params=inference_params, # currently always None - packed_seq_params=packed_seq_params, # currently always None - **(extra_block_kwargs or {}), - **kwargs, - ) - - return output diff --git a/verl/models/mcore/qwen2_5_vl/rope_utils.py b/verl/models/mcore/qwen2_5_vl/rope_utils.py deleted file mode 100644 index fadc74daabe..00000000000 --- a/verl/models/mcore/qwen2_5_vl/rope_utils.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright 2025 Bytedance Ltd. 
and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import logging -from typing import Optional - -import torch -from megatron.core.models.common.embeddings.rope_utils import * -from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd -from torch import Tensor - -logger = logging.getLogger(__name__) - - -# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index -def get_rope_index( - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, -): - """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. - - Examples: - - input_ids: [T T T T T], here T is for text. 
- temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embedding for text part. - - Examples: - - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each - second. - tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal - tokens" are conceptually packed into a one-second interval of the video. - In this case, we have 25 tokens per second. So each second of the video will be - represented with 25 separate time points. It essentially defines the temporal - granularity. - temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. - interval: The step size for the temporal position IDs, calculated as tokens_per_second * - temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be - have a difference of 50 in the temporal position IDs. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] - vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] - vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] - Here we calculate the text start position_ids as the max vision position_ids plus 1. 
- - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): - The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) - """ - spatial_merge_size = 2 - tokens_per_second = 2 - image_token_id = 151655 - video_token_id = 151656 - vision_start_token_id = 151652 - mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): - total_input_ids = input_ids - if attention_mask is None: - attention_mask = torch.ones_like(total_input_ids) - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) - for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens 
== image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - second_per_grid_t = 0 - image_index += 1 - remain_images -= 1 - ed = ed_image - - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - if second_per_grid_ts is not None: - second_per_grid_t = second_per_grid_ts[video_index] - else: - second_per_grid_t = 1.0 - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - range_tensor = torch.arange(llm_grid_t).view(-1, 1) - expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) - - time_tensor = expanded_range * second_per_grid_t * tokens_per_second - - time_tensor_long = time_tensor.long() - t_index = time_tensor_long.flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * 
llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) - - return position_ids, mrope_position_deltas - - -def apply_rotary_pos_emb_thd_absolute( - t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False -) -> Tensor: - """A baseline implementation of applying RoPE for `thd` format. - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - - Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. 
- """ - return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) - - -def apply_rotary_pos_emb_absolute( - t: Tensor, - freqs: Tensor, - config: TransformerConfig, - cu_seqlens: Optional[Tensor] = None, -): - """ - Reroute to the appropriate apply_rotary_pos_emb function depending on - bshd (conventional) / thd (packed seq) format - - In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] - """ - - if config.apply_rope_fusion: - if cu_seqlens is None: - # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 - if freqs.shape[1] > 1: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) - else: - return fused_apply_rotary_pos_emb(t, freqs) - else: - # NOTE: as expected, thd format can use bshd - return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) - else: - if cu_seqlens is None: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) - else: - return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/verl/models/mcore/qwen2_5_vl/vision_config.py b/verl/models/mcore/qwen2_5_vl/vision_config.py deleted file mode 100644 index 0631c90f616..00000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core import parallel_state -from megatron.core.transformer import TransformerConfig - - -def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: - # Given a Transformer Config from decoder, build vision encoder config - # diff: out_hidden_size & intermediate_size - - # mlp: hidden_size -> intermediate_size -> embed_dim, silu - # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on - if config.num_layers in [28, 36]: - config.ffn_hidden_size = 3420 - else: - config.ffn_hidden_size = 3456 - - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth - else: - config.num_layers = 32 # depth - config.num_attention_heads = 16 # num_heads - config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) - config.add_qkv_bias = True # qkv_proj in attn has bias - config.hidden_size = 1280 # hidden_size - config.hidden_dropout = 0.0 - config.attention_dropout = 0.0 - - # config.gated_linear_unit = False # no gated - # config.activation_func = quick_gelu # hidden_act - config.kv_channels = config.hidden_size // config.num_attention_heads - config.num_query_groups = config.num_attention_heads # no GQA - config.layernorm_zero_centered_gamma = False # False - config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) - config.bias_activation_fusion = False # no swiglu, set false - config.bias_dropout_fusion = False # no dropout, set false - 
config.attention_softmax_in_fp32 = True # use True - # config.normalization = 'LayerNorm' # use RMSNorm - config.seq_length = 1 - - config.tp_comm_overlap = False - config.sequence_parallel = False - config.temporal_patch_size = 2 - config.patch_size = 14 - config.in_channels = 3 - config.spatial_merge_size = 2 - - config.fullatt_block_indexes = [7, 15, 23, 31] - config._qwen2_5_vl_window_size = 112 - return config - - -def get_vision_projection_config( - config: TransformerConfig, embed_dim: int, spatial_merge_size: int -) -> TransformerConfig: - # merger: - # context_dim = hidden_size * merge_size**2 - # out_hidden_size = hidden_size - # context_dim -> context_dim -> out_hidden_size - # MLP: - # input_size -> ffn_hidden_size -> hidden_size - # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) - config.gated_linear_unit = False - config.bias_activation_fusion = False - config.add_bias_linear = True - config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) - config.activation_func = torch.nn.functional.gelu - config.tp_comm_overlap = False - config.sequence_parallel = False - return config diff --git a/verl/models/mcore/qwen2_5_vl/vision_model.py b/verl/models/mcore/qwen2_5_vl/vision_model.py deleted file mode 100644 index 06b4fd32806..00000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_model.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from megatron.core import InferenceParams -from megatron.core.models.common.vision_module.vision_module import VisionModule -from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from torch import nn -from torch.nn import functional as F - -from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock - - -# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - return hidden_states - - -# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - 
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs.float() - - -class Qwen2_5VisionModel(VisionModule): - """Qwen2.5 ViT vision model. - - Args: - transformer_config (TransformerConfig): Transformer config. - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. - ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. - add_class_token (bool, optional): Include a class token. Defaults to True. - class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. - patch_dim (int): Image patch size. - img_h (int): Input image height. - img_w (int): Input image width. - """ - - def __init__( - self, - transformer_config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - projection_config: TransformerConfig, - projection_layer_spec: ModuleSpec, - projection_type: str = "mlp", - pre_process: bool = True, - post_process: bool = False, - ) -> None: - super().__init__(config=transformer_config) - - self.spatial_merge_size = transformer_config.spatial_merge_size - - embed_dim = transformer_config.hidden_size - num_heads = transformer_config.num_attention_heads - temporal_patch_size = transformer_config.temporal_patch_size - patch_size = transformer_config.patch_size - in_channels = transformer_config.in_channels - - self.patch_size = transformer_config.patch_size - self.fullatt_block_indexes = transformer_config.fullatt_block_indexes - self.window_size = transformer_config._qwen2_5_vl_window_size - self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - - self.max_sequence_length = transformer_config.seq_length - self.patch_embed = PatchEmbed( - patch_size=patch_size, - 
temporal_patch_size=temporal_patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - ) - - head_dim = embed_dim // num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) - - self.model_type = ModelType.encoder_or_decoder - self.pre_process = pre_process - self.post_process = post_process - - # Transformer layers. - # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting - # pipeline parallelism. - # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. - self.decoder = TransformerBlock( - config=transformer_config, - spec=transformer_layer_spec, - pre_process=self.pre_process, - post_process=self.post_process, - post_layer_norm=True, - ) - - self.merge_hidden_size = projection_config.ffn_hidden_size - self.square_merge_size = self.merge_hidden_size // embed_dim - - if self.post_process: - self.projection = MultimodalProjector( - projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size - ) - else: - self.projection = None - - self.input_tensor = None - - def set_input_tensor(self, input_tensor: torch.Tensor) -> None: - """Sets input tensor to the model. - - Args: - input_tensor (Tensor): Sets the input tensor for the model. 
- """ - if self.pre_process: # always True - self.input_tensor = input_tensor - else: - raise NotImplementedError() - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - 
num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - - def forward( - self, - vision_data: Optional[torch.Tensor], - grid_thw: torch.Tensor, - inference_params: Optional[InferenceParams] = None, - extra_block_kwargs: dict = None, - ) -> torch.Tensor: - """Forward function of the Qwen2 Vision Model. This function passes the input tensors - through the embedding layer and then the transformer. - - Args: - x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] - grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame - packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend - - Returns: - x (torch.Tensor): output after final transformer block of shape [b, s, h]. 
- """ - assert grid_thw is not None - assert self.input_tensor is None - assert inference_params is None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - vision_data = self.patch_embed(vision_data) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=vision_data.device, - dtype=torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - - seq_len, _ = vision_data.size() - vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - vision_data = vision_data[window_index, :, :] - vision_data = vision_data.reshape(seq_len, 1, -1) - - rotary_pos_emb = self.rot_pos_emb(grid_thw) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) - - hidden_states = self.decoder( - hidden_states=vision_data, - attention_mask=None, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), - packed_seq_params_full=self.build_packed_seq_params(grid_thw), - fullatt_block_indexes=self.fullatt_block_indexes, - **(extra_block_kwargs or {}), - ) - - hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) - reverse_indices = torch.argsort(window_index) - return hidden_states[reverse_indices, :] - - def build_packed_seq_params( - self, - grid_thw: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor] = None, - ) -> PackedSeqParams: - # NOTE: each frame is a sequence (rather than each grid) - if grid_thw is not None: - seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) - cu_seqlens = seqlens.cumsum(dim=0) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() - else: - seqlens = 
cu_seqlens[1:] - cu_seqlens[:-1] - - max_seqlen_q = seqlens.max() - return PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - qkv_format="thd", - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_q, - ) diff --git a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py deleted file mode 100644 index 8f765a0ff63..00000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from megatron.core.transformer.transformer_block import * - - -class Qwen2_5VisionTransformerBlock(TransformerBlock): - def _checkpointed_forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - context: Tensor, - context_mask: Tensor, - rotary_pos_emb: Tensor, - attention_bias: Tensor, - packed_seq_params: PackedSeqParams, - packed_seq_params_full: PackedSeqParams, - fullatt_block_indexes, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): - for index in range(start, end): - if index in fullatt_block_indexes: - packed_seq_params_now = packed_seq_params_full - else: - packed_seq_params_now = packed_seq_params - layer = self._get_layer(index) - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - inference_context=None, - packed_seq_params=packed_seq_params_now, - ) - return hidden_states, context - - return custom_forward - - def checkpoint_handler(forward_func): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: - return te_checkpoint( - forward_func, - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - else: - return tensor_parallel.checkpoint( - forward_func, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - - if self.config.recompute_method == "uniform": - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. 
- layer_idx = 0 - while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states, context = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers) - ) - - layer_idx += self.config.recompute_num_layers - - elif self.config.recompute_method == "block": - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - recompute_skip_num_layers = 0 - for layer_idx in range(self.num_layers_per_pipeline_rank): - # Skip recomputation when input grad computation is not needed. - # Need to have at least one input tensor with gradient computation - # for re-enterant autograd engine. - if self.config.fp8 and not hidden_states.requires_grad: - recompute_skip_num_layers += 1 - if ( - layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers - ): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) - else: - hidden_states, context = custom(layer_idx, layer_idx + 1)( - hidden_states, attention_mask, context, context_mask, rotary_pos_emb - ) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def forward( - self, - hidden_states: Union[Tensor, WrappedTensor], - attention_mask: Optional[Tensor], - context: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - rotary_pos_emb: Optional[Tensor] = None, - rotary_pos_cos: Optional[Tensor] = None, - rotary_pos_sin: Optional[Tensor] = None, - attention_bias: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[Tensor] = None, - packed_seq_params_full: PackedSeqParams = None, - fullatt_block_indexes=None, - *, - inference_params: Optional[BaseInferenceContext] = None, - ): - """ - Perform the forward pass through the 
transformer block. - - This method handles the core computation of the transformer, including - self-attention, optional cross-attention, and feed-forward operations. - - Args: - hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] - where s is the sequence length, b is the batch size, and h is the hidden size. - Can be passed as a WrappedTensor during inference to avoid an obsolete - reference in the calling function. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention. - context_mask (Tensor, optional): Mask for cross-attention context - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable - to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. - Used as an alternative to apply attention mask for TE cuDNN attention. - inference_context (BaseInferenceContext, optional): Parameters for inference-time - optimizations. - packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence - processing. - - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Delete the obsolete reference to the initial input tensor if necessary - if isinstance(hidden_states, WrappedTensor): - hidden_states = hidden_states.unwrap() - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Update the inference parameters with the current batch size in case it is variable - if inference_context and not self.training: - inference_context.current_batch_size = hidden_states.size(1) - - # Viewless tensor. 
- # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), - # otherwise do nothing extra at the outer level - # if we are using other fp8 recipes, then the context manager enter&exit are free - # we can wrap fp8_context within the for loop over layers, so that we can fine-grained - # control which layer will be fp8 or bf16 - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() - - with rng_context, outer_fp8_context: - # Forward pass. 
- if self.config.recompute_granularity == "full" and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - packed_seq_params_full=packed_seq_params_full, - fullatt_block_indexes=fullatt_block_indexes, - ) - else: - for l_no, layer in enumerate(self.layers): - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() - ) - if l_no in fullatt_block_indexes: - packed_seq_params_now = packed_seq_params_full - else: - packed_seq_params_now = packed_seq_params - with self.offload_context, inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params_now, - sequence_len_offset=sequence_len_offset, - ) - - if ( - torch.is_grad_enabled() - and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None - ): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) - - # Final layer norm. - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - # TENorm produces a "viewed" tensor. This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. 
- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - return hidden_states diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index d8c7b2cfa86..d244fcfa9a0 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -19,255 +19,43 @@ from enum import Enum from typing import Callable -import torch -import torch.nn as nn - -from .config_converter import ( - PretrainedConfig, - TransformerConfig, - hf_to_mcore_config_dense, - hf_to_mcore_config_dpskv3, - hf_to_mcore_config_llama4, - hf_to_mcore_config_mixtral, - hf_to_mcore_config_qwen2_5_vl, - hf_to_mcore_config_qwen2moe, - hf_to_mcore_config_qwen3moe, -) from .model_forward import gptmodel_forward_no_padding, model_forward_gen from .model_forward_fused import fused_forward_model_gen -from .model_initializer import ( - BaseModelInitializer, - DeepseekV3Model, - DenseModel, - MixtralModel, - Qwen2MoEModel, - Qwen3MoEModel, - Qwen25VLModel, -) -from .weight_converter import ( - McoreToHFWeightConverterDense, - McoreToHFWeightConverterDpskv3, - McoreToHFWeightConverterMixtral, - McoreToHFWeightConverterQwen2_5_VL, - McoreToHFWeightConverterQwen2Moe, - McoreToHFWeightConverterQwen3Moe, -) - -class SupportedModel(Enum): - LLAMA = "LlamaForCausalLM" # tested - QWEN2 = "Qwen2ForCausalLM" # tested - QWEN2_MOE = "Qwen2MoeForCausalLM" # pending - DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested - MIXTRAL = "MixtralForCausalLM" # tested - QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported - LLAMA4 = "Llama4ForConditionalGeneration" # not tested - QWEN3 = "Qwen3ForCausalLM" # tested - QWEN3_MOE = "Qwen3MoeForCausalLM" # tested - GLM4_MOE = "Glm4MoeForCausalLM" - QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification" +class SupportedVLM(Enum): + QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration" QWEN3_VL = "Qwen3VLForConditionalGeneration" - GPT_OSS = 
"GptOssForCausalLM" - - -# Registry for model configuration converters -MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { - SupportedModel.LLAMA: hf_to_mcore_config_dense, - SupportedModel.QWEN2: hf_to_mcore_config_dense, - SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, - SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3, - SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral, - SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, - SupportedModel.LLAMA4: hf_to_mcore_config_llama4, - SupportedModel.QWEN3: hf_to_mcore_config_dense, - SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, - SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense, -} - -# Registry for model initializers -MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = { - SupportedModel.LLAMA: DenseModel, - SupportedModel.QWEN2: DenseModel, - SupportedModel.QWEN2_MOE: Qwen2MoEModel, - SupportedModel.MIXTRAL: MixtralModel, - SupportedModel.DEEPSEEK_V3: DeepseekV3Model, - SupportedModel.QWEN2_5_VL: Qwen25VLModel, - SupportedModel.LLAMA4: DenseModel, - SupportedModel.QWEN3: DenseModel, - SupportedModel.QWEN3_MOE: Qwen3MoEModel, - SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel, -} - -# Registry for model forward functions -MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { - SupportedModel.LLAMA: model_forward_gen(), - SupportedModel.QWEN2: model_forward_gen(), - SupportedModel.QWEN2_MOE: model_forward_gen(), - SupportedModel.MIXTRAL: model_forward_gen(), - SupportedModel.DEEPSEEK_V3: model_forward_gen(), - SupportedModel.LLAMA4: model_forward_gen(), - SupportedModel.QWEN3: model_forward_gen(), - SupportedModel.QWEN3_MOE: model_forward_gen(), - SupportedModel.QWEN2_5_VL: model_forward_gen(True), - SupportedModel.QWEN3_MOE_VL: model_forward_gen(True), - SupportedModel.QWEN3_VL: model_forward_gen(True), - SupportedModel.DEEPSEEK_V3: model_forward_gen(), - 
SupportedModel.GLM4_MOE: model_forward_gen(), - SupportedModel.QWEN3_TOKEN_CLASSIFICATION: model_forward_gen(), - SupportedModel.GPT_OSS: model_forward_gen(), -} - -# Registry for model forward functions -MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = { - SupportedModel.LLAMA: gptmodel_forward_no_padding, - SupportedModel.QWEN2: gptmodel_forward_no_padding, - SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding, - SupportedModel.MIXTRAL: gptmodel_forward_no_padding, - SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, - SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding, - SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding, - SupportedModel.QWEN3_VL: gptmodel_forward_no_padding, - SupportedModel.LLAMA4: gptmodel_forward_no_padding, - SupportedModel.QWEN3: gptmodel_forward_no_padding, - SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding, - SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, - SupportedModel.GLM4_MOE: gptmodel_forward_no_padding, - SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, - SupportedModel.GPT_OSS: gptmodel_forward_no_padding, -} - -# Registry for model forward functions -MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { - SupportedModel.LLAMA: fused_forward_model_gen(), - SupportedModel.QWEN2: fused_forward_model_gen(), - SupportedModel.QWEN2_MOE: fused_forward_model_gen(), - SupportedModel.MIXTRAL: fused_forward_model_gen(), - SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(), - SupportedModel.QWEN2_5_VL: fused_forward_model_gen(True), - SupportedModel.QWEN3_MOE_VL: fused_forward_model_gen(True), - SupportedModel.QWEN3_VL: fused_forward_model_gen(True), - SupportedModel.LLAMA4: fused_forward_model_gen(), - SupportedModel.QWEN3: fused_forward_model_gen(), - SupportedModel.QWEN3_MOE: fused_forward_model_gen(), - SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(), - SupportedModel.GLM4_MOE: fused_forward_model_gen(), - SupportedModel.GPT_OSS: 
fused_forward_model_gen(), -} -# Registry for model weight converters -MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = { - SupportedModel.LLAMA: McoreToHFWeightConverterDense, - SupportedModel.QWEN2: McoreToHFWeightConverterDense, - SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, - SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral, - SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, - SupportedModel.QWEN3: McoreToHFWeightConverterDense, - SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, - SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, - SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense, -} - -def get_supported_model(model_type: str) -> SupportedModel: - try: - return SupportedModel(model_type) - except ValueError as err: - supported_models = [e.value for e in SupportedModel] - raise NotImplementedError( - f"Model Type: {model_type} not supported. Supported models: {supported_models}" - ) from err - - -def hf_to_mcore_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - """Convert huggingface PretrainedConfig to mcore TransformerConfig. - - Args: - hf_config: The huggingface PretrainedConfig. - dtype: The dtype of the model. - **override_transformer_config_kwargs: The kwargs to override the transformer config. - - Returns: - The mcore TransformerConfig. 
- """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) - - -def init_mcore_model( - tfconfig: TransformerConfig, - hf_config: PretrainedConfig, - pre_process: bool = True, - post_process: bool = None, - *, - share_embeddings_and_output_weights: bool = False, - value: bool = False, - **extra_kwargs, # may be used for vlm and moe -) -> nn.Module: - """ - Initialize a Mcore model. - - Args: - tfconfig: The transformer config. - hf_config: The HuggingFace config. - pre_process: Optional pre-processing function. - post_process: Optional post-processing function. - share_embeddings_and_output_weights: Whether to share embeddings and output weights. - value: Whether to use value. - **extra_kwargs: Additional keyword arguments. - - Returns: - The initialized model. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - initializer_cls = MODEL_INITIALIZER_REGISTRY[model] - initializer = initializer_cls(tfconfig, hf_config) - return initializer.initialize( - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - value=value, - **extra_kwargs, - ) - - -def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: +def get_mcore_forward_fn(hf_config) -> Callable: """ Get the forward function for given model architecture. 
""" assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_FORWARD_REGISTRY[model] + if hf_config.architectures[0] in SupportedVLM: + return model_forward_gen(True) + else: + # default to language model + return model_forward_gen(False) -def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable: +def get_mcore_forward_no_padding_fn(hf_config) -> Callable: """ Get the forward function for given model architecture. """ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_FORWARD_NOPAD_REGISTRY[model] + return gptmodel_forward_no_padding -def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: +def get_mcore_forward_fused_fn(hf_config) -> Callable: """ Get the forward function for given model architecture. """ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_FORWARD_FUSED_REGISTRY[model] - - -def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: - """ - Get the weight converter for given model architecture. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - tfconfig = hf_to_mcore_config(hf_config, dtype) - return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) + if hf_config.architectures[0] in SupportedVLM: + return fused_forward_model_gen(True) + else: + # default to language model + return fused_forward_model_gen(False) diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py deleted file mode 100644 index 2a954b2417c..00000000000 --- a/verl/models/mcore/saver.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2025 Bytedance Ltd. 
and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank( - tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 -): - """Calculate global rank with support for CP/EP parallelism""" - - # Get parallel sizes for each dimension - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - cp_size = mpu.get_context_parallel_world_size() - # ep_size = mpu.get_expert_model_parallel_world_size() - - # Verify total GPU count matches (must be consistent with parallel_state.py) - total_size = tp_size * dp_size * pp_size * cp_size - assert total_size == torch.distributed.get_world_size(), ( - f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" - ) - - # Core calculation logic (corresponds to RankGenerator order parameter) - # Assumes default order is 
"tp-cp-ep-dp-pp" - return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. 
- """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].decoder.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = 
[tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - 
full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - 
state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * 
total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - 
# Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - if gpt_model_module.config.qk_layernorm: - _broadcast_tensor( - sync_layer.self_attention.q_layernorm.weight, - f"{layer_name}.self_attn.q_norm.weight", - src_pp_rank=src_pp_rank, - ) - _broadcast_tensor( - sync_layer.self_attention.k_layernorm.weight, - f"{layer_name}.self_attn.k_norm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - if gpt_model_module.config.add_qkv_bias: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.linear_fc1.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - 
sync_layer.mlp.linear_fc2.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - if tie_word_embeddings: - print_rank_0("tie word embedding skip load lm_head...") - else: - print_rank_0("collecting lm_head...") - - if is_value_model: - lm_head_weight = None - if pp_rank == pp_size - 1: - lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) - _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict - - -def merge_megatron_ckpt_gptmodel_qwen_moe( - wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") - - -def merge_megatron_ckpt_gptmodel_qwen2_5_vl( - wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") - - -def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") - - -def merge_megatron_ckpt_gptmodel_mixtral( - wrapped_models, config, dtype, 
is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/verl/models/mcore/weight_converter.py b/verl/models/mcore/weight_converter.py deleted file mode 100644 index 791513f32d1..00000000000 --- a/verl/models/mcore/weight_converter.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# online convert mcore weight to pure huggingface weight, no any fusion -# including format conversion and name mapping -# not including resharding -import torch -from megatron.core.transformer import TransformerConfig -from transformers import PretrainedConfig - - -class McoreToHFWeightConverterBase: - def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig): - self.hf_config = hf_config - self.mcore_config = mcore_config - - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: - raise NotImplementedError - - -class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # 'decoder.layers.0.self_attention.linear_proj.weight' - # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' - # 'decoder.layers.0.self_attention.linear_qkv.weight' - # 'decoder.layers.0.self_attention.linear_qkv.bias' - layer_number = name.split(".")[2] - convert_names = [] - if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: - param_type = name.split(".")[-1] - assert param_type == "bias" or param_type == "weight" - convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") - assert len(params) == 3 - elif "self_attention.linear_proj.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") - assert len(params) == 1 - elif "self_attention.linear_qkv.layer_norm_weight" in name: - convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight") - assert len(params) == 1 - elif "self_attention.q_layernorm.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight") - assert len(params) == 
1 - elif "self_attention.k_layernorm.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' - # 'decoder.layers.0.mlp.linear_fc1.weight' - # 'decoder.layers.0.mlp.linear_fc2.weight' - layer_number = name.split(".")[2] - convert_names = [] - if "mlp.linear_fc1.weight" in name: - # split gate_proj and up_proj - convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") - assert len(params) == 2 - elif "mlp.linear_fc1.layer_norm_weight" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") - assert len(params) == 1 - elif "mlp.linear_fc2.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - direct_name_mapping = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "output_layer.weight": "lm_head.weight", - } - if name in direct_name_mapping: - return [direct_name_mapping[name]], [params_one_group[0]] - - if "self_attention" in name: - return self._convert_attention_param(name, params_one_group) - elif "mlp" in name: - return self._convert_mlp_param(name, params_one_group) - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - - -class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): - def 
class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):
    """Qwen2-MoE converter: routed experts plus a shared expert and its gate."""

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map MoE MLP params to HF names.

        Handled mcore names (N = layer index, E = expert index):
          decoder.layers.N.pre_mlp_layernorm.weight
          decoder.layers.N.mlp.router.weight
          decoder.layers.N.mlp.shared_experts.gate_weight
          decoder.layers.N.mlp.shared_experts.linear_fc1.weight
          decoder.layers.N.mlp.shared_experts.linear_fc2.weight
          decoder.layers.N.mlp.experts.linear_fc1.weightE
          decoder.layers.N.mlp.experts.linear_fc2.weightE
        """
        hf_layer = f"model.layers.{name.split('.')[2]}"
        out_names: list[str] = []
        if "pre_mlp_layernorm" in name:
            out_names.append(f"{hf_layer}.post_attention_layernorm.weight")
            assert len(params) == 1
        elif "mlp.router.weight" in name:
            out_names.append(f"{hf_layer}.mlp.gate.weight")
            assert len(params) == 1
        elif "shared_experts.gate_weight" in name:
            out_names.append(f"{hf_layer}.mlp.shared_expert_gate.weight")
            assert len(params) == 1
        elif "shared_experts.linear_fc1.weight" in name:
            # fused fc1 splits into gate_proj and up_proj
            out_names.append(f"{hf_layer}.mlp.shared_expert.gate_proj.weight")
            out_names.append(f"{hf_layer}.mlp.shared_expert.up_proj.weight")
            assert len(params) == 2
        elif "shared_experts.linear_fc2.weight" in name:
            out_names.append(f"{hf_layer}.mlp.shared_expert.down_proj.weight")
            assert len(params) == 1
        elif "mlp.experts.linear_fc1" in name:
            # per-expert fused fc1 splits into gate_proj and up_proj
            expert_id = name.split("weight")[-1]
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.gate_proj.weight")
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.up_proj.weight")
            assert len(params) == 2
        elif "mlp.experts.linear_fc2" in name:
            expert_id = name.split("weight")[-1]
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.down_proj.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return out_names, params
class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense):
    """Qwen2.5-VL converter: handles both the language model and the vision tower."""

    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Dispatch a parameter to the direct mapping or the attention/MLP handlers."""
        direct_name_mapping = {
            "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "language_model.decoder.final_layernorm.weight": "model.norm.weight",
            "language_model.output_layer.weight": "lm_head.weight",
            "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight",
            "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight",
            "vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight",
            "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias",
            "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight",
            "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias",
        }
        if name in direct_name_mapping:
            return [direct_name_mapping[name]], [params_one_group[0]]
        if "self_attention" in name:
            return self._convert_attention_param(name, params_one_group)
        if "mlp" in name:
            return self._convert_mlp_param(name, params_one_group)
        raise NotImplementedError(f"Unsupported parameter name: {name}")

    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map attention params for either tower ('language_model' or 'vision_model')."""
        model_type, _, _, layer_number = name.split(".")[:4]
        suffix = ".".join(name.split(".")[-3:])
        convert_names = []
        if model_type == "language_model":
            table = {
                "self_attention.linear_qkv.bias": [
                    "self_attn.q_proj.bias",
                    "self_attn.k_proj.bias",
                    "self_attn.v_proj.bias",
                ],
                "self_attention.linear_qkv.weight": [
                    "self_attn.q_proj.weight",
                    "self_attn.k_proj.weight",
                    "self_attn.v_proj.weight",
                ],
                "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
                "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
            }
            target = table.get(suffix)
            if isinstance(target, list):
                # fused qkv arrives as three tensors: q, k, v
                assert len(params) == len(target)
                convert_names.extend(f"model.layers.{layer_number}.{one}" for one in target)
            else:
                assert len(params) == 1
                convert_names.append(f"model.layers.{layer_number}.{target}")
        elif model_type == "vision_model":
            table = {
                "self_attention.linear_proj.weight": "attn.proj.weight",
                "self_attention.linear_proj.bias": "attn.proj.bias",
                "self_attention.linear_qkv.layer_norm_weight": "norm1.weight",
            }
            target = table.get(suffix, None)
            if target is None:
                # fused qkv: concatenate q/k/v into the single HF qkv parameter
                assert "linear_qkv" in suffix
                assert len(params) == 3
                params = [torch.cat(params, dim=0)]
                kind = "bias" if "bias" in suffix else "weight"
                convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.{kind}")
            else:
                assert len(params) == 1
                convert_names.append(f"visual.blocks.{layer_number}.{target}")
        else:
            raise NotImplementedError(f"Unsupported model type: {model_type}")
        return convert_names, params

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map MLP params for either tower; both use the fused-fc1 gate/up split."""
        model_type, _, _, layer_number = name.split(".")[:4]
        suffix = ".".join(name.split(".")[-3:])
        if model_type == "language_model":
            prefix = f"model.layers.{layer_number}"
            fc1_norm_target = "post_attention_layernorm.weight"
        elif model_type == "vision_model":
            prefix = f"visual.blocks.{layer_number}"
            fc1_norm_target = "norm2.weight"
        else:
            raise NotImplementedError(f"Unsupported model type: {model_type}")
        table = {
            "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"],
            "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"],
            "mlp.linear_fc2.weight": "mlp.down_proj.weight",
            "mlp.linear_fc2.bias": "mlp.down_proj.bias",
            "mlp.linear_fc1.layer_norm_weight": fc1_norm_target,
        }
        convert_names = []
        target = table.get(suffix)
        if isinstance(target, list):
            assert len(params) == len(target)
            convert_names.extend(f"{prefix}.{one}" for one in target)
        else:
            assert len(params) == 1
            convert_names.append(f"{prefix}.{target}")
        return convert_names, params
class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase):
    """DeepSeek-V3 converter: MLA attention, MoE MLP, optional MTP layer."""

    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map MLA attention params (one tensor each) to HF names.

        mcore -> hf:
          input_layernorm.weight                              -> input_layernorm.weight
          self_attention.linear_proj.weight                   -> self_attn.o_proj.weight
          self_attention.linear_q_proj.weight                 -> self_attn.q_proj.weight
          self_attention.linear_kv_down_proj.weight           -> self_attn.kv_a_proj_with_mqa.weight
          self_attention.linear_kv_up_proj.layer_norm_weight  -> self_attn.kv_a_layernorm.weight
          self_attention.linear_kv_up_proj.weight             -> self_attn.kv_b_proj.weight
          self_attention.linear_q_down_proj.weight            -> self_attn.q_a_proj.weight
          self_attention.linear_q_up_proj.weight              -> self_attn.q_b_proj.weight
          self_attention.linear_q_up_proj.layer_norm_weight   -> self_attn.q_a_layernorm.weight
        """
        name_map_after_layer = {
            "input_layernorm.weight": "input_layernorm.weight",
            "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
            "self_attention.linear_q_proj.weight": "self_attn.q_proj.weight",
            "self_attention.linear_kv_down_proj.weight": "self_attn.kv_a_proj_with_mqa.weight",
            "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight",
            "self_attention.linear_kv_up_proj.weight": "self_attn.kv_b_proj.weight",
            "self_attention.linear_q_down_proj.weight": "self_attn.q_a_proj.weight",
            "self_attention.linear_q_up_proj.weight": "self_attn.q_b_proj.weight",
            "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight",
        }
        assert len(params) == 1
        convert_names = []
        layer_number = name.split(".")[2]
        name_after_layer = name.split(f".{layer_number}.")[1]
        convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}")
        return convert_names, params

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map dense and MoE MLP params to HF names.

        Dense layers use mlp.linear_fc1/linear_fc2 (+ optional shared_experts);
        MoE layers use pre_mlp_layernorm, mlp.router.{weight,expert_bias} and
        per-expert mlp.experts.linear_fc{1,2}.weight<E>.
        """
        name_map_after_layer = {
            "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight",
            "mlp.linear_fc2.weight": "mlp.down_proj.weight",
            "mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight",
            "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"],
            "mlp.shared_experts.linear_fc1.weight": [
                "mlp.shared_experts.gate_proj.weight",
                "mlp.shared_experts.up_proj.weight",
            ],
            "pre_mlp_layernorm.weight": "post_attention_layernorm.weight",
            "mlp.router.weight": "mlp.gate.weight",
            "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias",
        }
        convert_names = []
        layer_number = name.split(".")[2]
        name_after_layer = name.split(f".{layer_number}.")[1]
        if name_after_layer in name_map_after_layer:
            mapped_name = name_map_after_layer[name_after_layer]
            if isinstance(mapped_name, list):
                # fused fc1 splits into gate_proj and up_proj
                assert len(params) == len(mapped_name)
                for one in mapped_name:
                    convert_names.append(f"model.layers.{layer_number}.{one}")
            else:
                assert len(params) == 1
                convert_names.append(f"model.layers.{layer_number}.{mapped_name}")
        else:
            if "mlp.experts.linear_fc1.weight" in name:
                expert_id = name.split("weight")[-1]
                convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
                convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
                assert len(params) == 2
            elif "mlp.experts.linear_fc2.weight" in name:
                expert_id = name.split("weight")[-1]
                convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
                assert len(params) == 1
            else:
                raise NotImplementedError(f"Unsupported parameter name: {name}")

        return convert_names, params

    def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map the multi-token-prediction (MTP) layer to HF names.

        Generalized: the MTP layer is appended after the last decoder layer, so
        its HF layer index equals ``num_layers`` (previously hard-coded to 61,
        the DeepSeek-V3 depth).
        """
        assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now"
        mtp_layer_idx = self.mcore_config.num_layers
        direct_name_mapping = {
            "mtp.layers.0.enorm.weight": f"model.layers.{mtp_layer_idx}.enorm.weight",
            "mtp.layers.0.hnorm.weight": f"model.layers.{mtp_layer_idx}.hnorm.weight",
            "mtp.layers.0.eh_proj.weight": f"model.layers.{mtp_layer_idx}.eh_proj.weight",
            "mtp.layers.0.final_layernorm.weight": f"model.layers.{mtp_layer_idx}.shared_head.norm.weight",
        }
        if name in direct_name_mapping:
            return [direct_name_mapping[name]], [params[0]]
        assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now"
        # Rewrite to a regular decoder-layer name so the standard converters apply.
        proxy_name = name.replace("mtp.layers.0.transformer_layer", f"decoder.layers.{mtp_layer_idx}")
        if "self_attention" in proxy_name or "input_layernorm.weight" in proxy_name:
            convert_names, params = self._convert_attention_param(proxy_name, params)
        elif "mlp" in proxy_name:
            convert_names, params = self._convert_mlp_param(proxy_name, params)
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params

    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Convert one mcore parameter group to ``(hf_names, hf_tensors)``."""
        direct_name_mapping = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.final_layernorm.weight": "model.norm.weight",
            "output_layer.weight": "lm_head.weight",
        }
        if name in direct_name_mapping:
            return [direct_name_mapping[name]], [params_one_group[0]]
        if "mtp" in name:
            return self._convert_mtp_param(name, params_one_group)
        elif "self_attention" in name or "input_layernorm.weight" in name:
            return self._convert_attention_param(name, params_one_group)
        elif "mlp" in name:
            return self._convert_mlp_param(name, params_one_group)
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")


class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):
    """Mixtral converter: block-sparse MoE; fused fc1 maps to w1/w3, fc2 to w2."""

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map MoE MLP params to HF block_sparse_moe names.

        Handled mcore names (N = layer index, E = expert index):
          decoder.layers.N.pre_mlp_layernorm.weight
          decoder.layers.N.mlp.router.weight
          decoder.layers.N.mlp.experts.linear_fc1.weightE
          decoder.layers.N.mlp.experts.linear_fc2.weightE
        """
        layer_number = name.split(".")[2]
        convert_names = []
        if "pre_mlp_layernorm" in name:
            convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
            assert len(params) == 1  # consistency: siblings all validate group size
        elif "mlp.router.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight")
            assert len(params) == 1
        elif "mlp.experts.linear_fc1.weight" in name:
            # fused fc1 splits into w1 (gate) and w3 (up)
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight")
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight")
            assert len(params) == 2
        elif "mlp.experts.linear_fc2.weight" in name:
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params
class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):
    """Qwen3-MoE converter: routed experts only (no shared expert)."""

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        """Map MoE MLP params to HF names.

        Handled mcore names (N = layer index, E = expert index):
          decoder.layers.N.pre_mlp_layernorm.weight
          decoder.layers.N.mlp.router.weight
          decoder.layers.N.mlp.experts.linear_fc1.weightE
          decoder.layers.N.mlp.experts.linear_fc2.weightE
        """
        hf_layer = f"model.layers.{name.split('.')[2]}"
        out_names: list[str] = []
        if "pre_mlp_layernorm" in name:
            out_names.append(f"{hf_layer}.post_attention_layernorm.weight")
            assert len(params) == 1
        elif "mlp.router.weight" in name:
            out_names.append(f"{hf_layer}.mlp.gate.weight")
            assert len(params) == 1
        elif "mlp.experts.linear_fc1" in name:
            # fused fc1 splits into gate_proj and up_proj
            expert_id = name.split("weight")[-1]
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.gate_proj.weight")
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.up_proj.weight")
            assert len(params) == 2
        elif "mlp.experts.linear_fc2" in name:
            expert_id = name.split("weight")[-1]
            out_names.append(f"{hf_layer}.mlp.experts.{expert_id}.down_proj.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return out_names, params
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/qwen2/megatron/__init__.py b/verl/models/qwen2/megatron/__init__.py deleted file mode 100644 index 57e33ee9e90..00000000000 --- a/verl/models/qwen2/megatron/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from .modeling_qwen2_megatron import ( - ParallelQwen2ForCausalLM, - # rmpad with megatron - ParallelQwen2ForCausalLMRmPad, - # rmpad with megatron and pipeline parallelism - ParallelQwen2ForCausalLMRmPadPP, - ParallelQwen2ForValueRmPad, - ParallelQwen2ForValueRmPadPP, - # original model with megatron - ParallelQwen2Model, -) - -__all__ = [ - "ParallelQwen2ForCausalLM", - "ParallelQwen2ForCausalLMRmPad", - "ParallelQwen2ForCausalLMRmPadPP", - "ParallelQwen2ForValueRmPad", - "ParallelQwen2ForValueRmPadPP", - "ParallelQwen2Model", -] diff --git a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py deleted file mode 100644 index 1ce90c5eb35..00000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py deleted file mode 100644 index 3168635c7fe..00000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
import torch.distributed as dist

from verl.utils.device import get_device_id, get_torch_device


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx.

    Returns:
        dict[int, tuple[int, int, int]]: global layer index ->
        (pp_rank, virtual_pp_rank, layer index inside the model chunk).
    """
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    # Layers must divide evenly across pipeline (and virtual pipeline) ranks.
    layers_per_chunk = config.num_hidden_layers // pp_size // virtual_pp_size
    assert layers_per_chunk * pp_size * virtual_pp_size == config.num_hidden_layers

    layer_map = {}
    for pp_rank_idx in range(pp_size):
        for vpp_rank_idx in range(virtual_pp_size):
            base = vpp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * layers_per_chunk
            for local_idx in range(layers_per_chunk):
                layer_map[base + local_idx] = (pp_rank_idx, vpp_rank_idx, local_idx)
    return layer_map
def load_state_dict_to_megatron_qwen2(
    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False
):
    """Load a merged (full) HF-style state_dict into sharded Megatron modules.

    Args:
        state_dict: full model weights keyed by HF parameter names.
        wrapped_models: one (possibly DDP/Float16-wrapped) module per virtual
            pipeline chunk, or a single module.
        config: HF model config (num_hidden_layers, hidden_size, head counts).
        params_dtype: dtype for the temporary fused qkv/gate_up buffers.
        is_value_model: if True, lm_head may be loaded from a 1-row value or
            reward head instead of the full vocabulary projection.
        tie_word_embeddings: if True, skip loading lm_head (tied with embeddings).
    """
    from megatron.core import DistributedDataParallel as LocalDDP
    from megatron.core import mpu
    from megatron.core.transformer.module import Float16Module
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    from verl.utils.logger import print_rank_0
    from verl.utils.megatron_utils import unwrap_model

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def fetch_params(module):
        # FIX: torch.distributed has no `fetch`; the intent is to synchronize
        # parameters from the data-parallel source rank via broadcast.
        for param in module.parameters():
            torch.distributed.broadcast(
                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()
            )

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        # FIX: messages previously interpolated the bound method `mp_group.rank`
        # instead of calling it.
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, list | tuple):
        # FIX: box a single module in a list; `list(module)` would try to
        # iterate the module rather than wrap it.
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (
        f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: "
        f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
    )

    models = [None] * len(wrapped_models)
    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _fetch_tensor(tensor, name) -> torch.Tensor:
        """Copy a full (unsharded) tensor from state_dict into `tensor` in place."""
        nonlocal state_dict
        if tensor is not None:
            tensor.data.copy_(state_dict[name], non_blocking=True)

    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """Copy this TP rank's chunk of a tensor-parallel weight into `tensor`."""
        nonlocal state_dict
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        if name in state_dict:
            full_weight = state_dict[name]
            if mutate_func is not None:
                full_weight = mutate_func(full_weight)
            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
            if tensor is not None:
                tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
        else:
            print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")

    # The vocab variant had an identical body; keep the name as an alias for
    # call-site clarity instead of duplicating the implementation.
    _fetch_tp_shard_tensor_vocab = _fetch_tp_shard_tensor

    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """Interleave gate/up weights into Megatron's fused gate_up layout and shard."""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        if gate_name in state_dict and up_name in state_dict:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(
                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()
            )
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(
                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)
                )
            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            if tensor is not None:
                tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
        else:
            print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")

    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
        """Interleave q/k/v weights into Megatron's fused qkv layout and shard."""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        assert q_name in state_dict and k_name in state_dict and v_name in state_dict
        full_weight_q = state_dict[q_name]
        full_weight_k = state_dict[k_name]
        full_weight_v = state_dict[v_name]

        hidden_size_per_head = config.hidden_size // config.num_attention_heads

        if config.num_key_value_heads >= tp_size:
            # Every TP rank owns a whole number of KV heads.
            q_size_tp = config.hidden_size // tp_size
            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
            total_size = q_size_tp + 2 * kv_size_tp
            if not bias:
                new_weight_qkv = torch.empty(
                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()
                )
            else:
                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())
            for i in range(tp_size):
                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
        else:
            # KV heads are replicated across TP ranks (GQA with few KV heads).
            q_size_tp = config.hidden_size // tp_size
            kv_size_tp = hidden_size_per_head
            total_size = q_size_tp + 2 * kv_size_tp
            if not bias:
                new_weight_qkv = torch.empty(
                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()
                )
            else:
                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())
            for i in range(tp_size):
                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                k_part = full_weight_k[start_idx:end_idx]
                v_part = full_weight_v[start_idx:end_idx]
                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))

        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
        if tensor is not None:
            tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)

    # Embeddings
    # -------------------
    print_rank_0("loading embeddings...")
    gpt_model_module = _get_gpt_model(models[0])
    if pp_rank == 0:
        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

    # Transformer layers
    # -------------------
    layer_map = _megatron_calc_layer_map(config)

    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    num_layer_per_pp = config.num_hidden_layers // pp_size
    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()

    layer_list = []
    if vpp_size is not None:
        for vpp_rank in range(vpp_size):
            num_layer_vpp_chunk = num_layer_per_pp // vpp_size
            num_layer_this_model = num_layer_vpp_chunk
            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * num_layer_vpp_chunk)
            layer_list.extend(list(range(offset, offset + num_layer_this_model)))
    else:
        num_layer_this_model = num_layer_per_pp
        offset = pp_rank * num_layer_per_pp
        layer_list.extend(list(range(offset, offset + num_layer_this_model)))

    for layer in layer_list:
        print(f"{torch.distributed.get_rank()} loading layer #{layer}...")
        layer_name = f"model.layers.{layer}"
        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

        print(
            f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, "
            f"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}"
        )

        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
        sync_layer = gpt_model_module.model.layers[dst_layer_idx]

        _fetch_tensor(
            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.input_layernorm.weight",
        )

        _fetch_tp_shard_tensor_qkv(
            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.q_proj.weight",
            f"{layer_name}.self_attn.k_proj.weight",
            f"{layer_name}.self_attn.v_proj.weight",
        )

        _fetch_tp_shard_tensor_qkv(
            sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.q_proj.bias",
            f"{layer_name}.self_attn.k_proj.bias",
            f"{layer_name}.self_attn.v_proj.bias",
            bias=True,
        )

        _fetch_tp_shard_tensor(
            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.o_proj.weight",
            chunk_dim=1,
        )

        _fetch_tensor(
            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.post_attention_layernorm.weight",
        )

        _fetch_tp_shard_tensor_gate_up(
            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.mlp.gate_proj.weight",
            f"{layer_name}.mlp.up_proj.weight",
        )

        _fetch_tp_shard_tensor(
            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.mlp.down_proj.weight",
            chunk_dim=1,
        )

    # Final Layernorm
    # -------------------
    print_rank_0("loading final layernorm...")
    gpt_model_module = _get_gpt_model(models[-1])
    _fetch_tensor(
        getattr(gpt_model_module.model.norm, "weight", None),
        "model.norm.weight",
    )

    if tie_word_embeddings:
        print_rank_0("tie_word_embeddings skip load lm_head")
    else:
        print_rank_0("loading lm_head...")
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

            if is_value_model:
                # A value/reward model projects to a single scalar; accept either
                # head name as long as it is 1-row.
                if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
                    _fetch_tensor(lm_head_weight, "lm_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
                    _fetch_tensor(lm_head_weight, "reward_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                else:
                    _fetch_tensor(None, "lm_head.weight")
                    print_rank_0("fail to match lm_head in value_model")
            else:
                _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")

    dist.barrier()
    get_torch_device().empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
- -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_qwen2( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = 
mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == 0: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - 
device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == 0: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=0, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = 
mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - 
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size 
// tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = 
torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True, - ) - - 
_broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - if tie_word_embeddings: - print_rank_0("tie_word_embeddings skip load lm_head") - else: - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head from value_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed 
{time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py deleted file mode 100644 index 737f73b4c61..00000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): - """given TP,DP,PP rank to get the global rank.""" - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) - # We only support TP-DP-PP grouping, for correctness when resharding - return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank - - -def 
_megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. 
- """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, 
src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, 
src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - 
tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) 
* total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.bias, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight, - 
f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - if tie_word_embeddings: - print_rank_0("tie word embedding skip load lm_head...") - else: - print_rank_0("collecting lm_head...") - - if is_value_model: - _broadcast_tensor( - gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None - else None, - "reward_head.weight", - src_pp_rank=pp_size - 1, - ) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict diff --git a/verl/models/qwen2/megatron/layers/__init__.py b/verl/models/qwen2/megatron/layers/__init__.py deleted file mode 100644 index 
263ea596fa7..00000000000 --- a/verl/models/qwen2/megatron/layers/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .parallel_attention import ParallelQwen2Attention -from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad -from .parallel_mlp import ParallelQwen2MLP -from .parallel_rmsnorm import ParallelQwen2RMSNorm - -__all__ = [ - "ParallelQwen2Attention", - "ParallelQwen2DecoderLayer", - "ParallelQwen2DecoderLayerRmPad", - "ParallelQwen2MLP", - "ParallelQwen2RMSNorm", -] diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py deleted file mode 100644 index 4e4f5910151..00000000000 --- a/verl/models/qwen2/megatron/layers/parallel_attention.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch.nn.functional as F -from einops import rearrange -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 - -import torch -from flash_attn.layers.rotary import apply_rotary_emb -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers import Qwen2Config - -from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): - """Qwen2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): - """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), 
persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class ParallelQwen2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config = config - self.megatron_config = megatron_config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # assign values after tp - tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, ( - f"num_head must be divisible by tp_size. 
Got num_head={self.num_heads}, tp_size={tp_size}" - ) - assert self.num_key_value_heads % tp_size == 0, ( - f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" - f"{self.num_key_value_heads}, tp_size={tp_size}" - ) - - self.num_heads_per_tp = self.num_heads // tp_size - self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size - self.hidden_size_per_tp = self.hidden_size // tp_size - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - - # [self.q_size, self.k_size, self.v_size] - self.qkv_proj = QKVParallelLinear( - input_size=self.hidden_size, - num_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - # bias=config.attention_bias, - bias=True, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - self.q_size = self.num_heads_per_tp * self.head_dim - self.k_size = self.num_key_value_heads_per_tp * self.head_dim - self.v_size = self.num_key_value_heads_per_tp * self.head_dim - - self.o_proj = tensor_parallel.RowParallelLinear( - input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, - # bias=config.attention_bias, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self._init_rope() - - def _init_rope(self): - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - 
max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) - - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) - attn_output = self.o_proj(attn_output)[0] - return attn_output - - -""" -Remove padding Attention -- Using Flash-attn 2 -- Compatible with sequence parallel -""" - - -def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): - batch_size = position_ids.shape[0] - - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) - k = pad_input(k, indices, batch_size, sequence_length) - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) - k_embed = index_first_axis(rearrange(k_embed, "b s ... 
-> (b s) ..."), indices) - - return q_embed, k_embed - - -# use flash-attn rotary embeddings with rmpad -# cos/sin shoudl be: (seq_length, rotary_dim / 2) -def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - return q_embed, k_embed - - -class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None, - ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel - - if self.megatron_config.sequence_parallel: - total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() - - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split( - [self.q_size, self.k_size, self.v_size], dim=-1 - ) # (total_nnz, 1, hidden_size) - - if self.megatron_config.sequence_parallel: - sequence_parallel_pad = total_nnz - cu_seqlens[-1] - total_nnz = cu_seqlens[-1] # total_nnz before sp padding - query_states = query_states[:total_nnz] - key_states = key_states[:total_nnz] - value_states = value_states[:total_nnz] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dime x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - - cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) 
- cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch - ) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, - # position_ids, indices, - - # It is recommended to use dropout with FA according to the docs - # when training. - dropout_rate = 0.0 # if not self.training else self.attn_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Qwen2RMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_in_batch, - max_seqlen_k=max_seqlen_in_batch, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - - attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() - - # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled - # Here we need to repad - if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) - - attn_output_unpad = self.o_proj(attn_output_unpad)[0] - return attn_output_unpad diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py 
b/verl/models/qwen2/megatron/layers/parallel_decoder.py deleted file mode 100644 index 3c8a2a6ee94..00000000000 --- a/verl/models/qwen2/megatron/layers/parallel_decoder.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional - -import torch -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import Qwen2Config - -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad -from .parallel_mlp import ParallelQwen2MLP -from .parallel_rmsnorm import ParallelQwen2RMSNorm - - -class ParallelQwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) - - self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Note: sequence parallel is hidden inside ColumnParallelLinear - # reduce scatter is hidden inside RowParallelLinear - - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # TODO: add sequence parallel operator all_gather here - - hidden_states = self.mlp(hidden_states) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs - - -class ParallelQwen2DecoderLayerRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) - - self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states # (total_nnz // sp, 1, hidden_size) - - hidden_states = 
self.input_layernorm(hidden_states) - - # Self Attention - # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) - # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - # shape changes same as attn - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs diff --git a/verl/models/qwen2/megatron/layers/parallel_linear.py b/verl/models/qwen2/megatron/layers/parallel_linear.py deleted file mode 100644 index e6d4a09f430..00000000000 --- a/verl/models/qwen2/megatron/layers/parallel_linear.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py - - -from megatron.core import tensor_parallel - - -class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.q_output_size = num_heads * head_dim - self.kv_output_size = num_key_value_heads * head_dim - self.head_dim = head_dim - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - input_size = self.input_size - output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - - super().__init__( - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.output_size = gate_ouput_size + up_output_size - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - super().__init__( - input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) diff --git a/verl/models/qwen2/megatron/layers/parallel_mlp.py b/verl/models/qwen2/megatron/layers/parallel_mlp.py deleted file mode 100644 index 672908a21ae..00000000000 --- a/verl/models/qwen2/megatron/layers/parallel_mlp.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
-# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers.activations import ACT2FN - -from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class ParallelQwen2MLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - - 
tp_size = mpu.get_tensor_model_parallel_world_size() - - self.gate_up_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - gate_ouput_size=self.intermediate_size, - up_output_size=self.intermediate_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - self.gate_size = self.intermediate_size // tp_size - - self.down_proj = tensor_parallel.RowParallelLinear( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - gate_up = self.gate_up_proj(x)[0] - gate, up = gate_up.split(self.gate_size, dim=-1) - return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py deleted file mode 100644 index 2f4c90dd44e..00000000000 --- a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numbers - -import torch -from apex.normalization.fused_layer_norm import fused_rms_norm_affine -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import Qwen2Config - -from verl.utils.megatron import sequence_parallel as sp_utils - - -class ParallelQwen2RMSNorm(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - if isinstance(config.hidden_size, numbers.Integral): - normalized_shape = (config.hidden_size,) - self.normalized_shape = torch.Size(normalized_shape) - self.weight = nn.Parameter(torch.ones(self.normalized_shape)) - self.variance_epsilon = config.rms_norm_eps - - if megatron_config.sequence_parallel: - sp_utils.mark_parameter_as_sequence_parallel(self.weight) - - def forward(self, hidden_states): - return fused_rms_norm_affine( - input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True, - ) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py deleted file mode 100644 index b3512f8afa5..00000000000 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ /dev/null @@ -1,737 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen2 model.""" - -from typing import Optional - -import torch -import torch.utils.checkpoint -from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel -from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.qwen2.configuration_qwen2 import Qwen2Config -from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast - -from verl.utils.device import get_device_name -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm - -""" -TODO: -1. Add weight initialization. Here we need to be careful on TP weight init. -2. Add sequence parallel -3. Load checkpoint from Qwen2 pretrained checkpoint -""" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class ParallelQwen2Model(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - 
Args: - input_ids: input ids. shape (batch_size, seq_length) - attention_mask: attention_mask. shape (batch_size, seq_length) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - batch_size, seq_length = input_ids.shape - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLM(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.model = ParallelQwen2Model(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401, E402 - - -class ParallelQwen2ModelRmPad(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - self.megatron_config = megatron_config - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - - def 
forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLMRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - self._init_head(config) - - def _init_head(self, config: Qwen2Config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - 
output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - logits = self.lm_head(hidden_states)[0] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - batch_size, sequence_length = input_ids.shape - - # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids = sp_utils.pad_to_sequence_parallel(input_ids) - - input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = outputs - - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids, attention_mask, position_ids) - output.logits = torch.squeeze(output.logits, dim=-1) - return output - - -""" -Support pipeline parallelism -""" - - -class ParallelQwen2ModelRmPadPP(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - This model definition supports pipeline parallelism. To support pp and vpp, - - This model only contains layer in this pp stage and vpp chunk - - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - self.megatron_config = megatron_config - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - else: - self.embed_tokens = None - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = megatron_config.pipeline_model_parallel_size - self.num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = megatron_config.virtual_pipeline_model_parallel_size - vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() - - if vpp_size is not None: - self.num_layer_vpp_chunk = self.num_layer_per_pp // 
vpp_size - self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) - else: - self.num_layer_this_model = self.num_layer_per_pp - offset = pp_rank * self.num_layer_per_pp - - self.layers = nn.ModuleList() - for i in range(self.num_layer_this_model): - layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) - self.layers.add_module(f"{i}", layer) - - if post_process: - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - else: - self.norm = None - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. 
shape (batch_size, seq_length) - - Returns: - - """ - if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron - # so need to deal with it by handle here: - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - else: - # self.hidden_states should be passed by Megatron - hidden_states = self.input_tensor - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - if self.post_process: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLMRmPadPP(nn.Module): - def __init__( - self, - config: Qwen2Config, - megatron_config: ModelParallelConfig, - pre_process, - post_process, - share_embeddings_and_output_weights, - ): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process - ) - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - if post_process: - self._init_head(config) - if pre_process or post_process: - self.setup_embeddings_and_output_layer() - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead 
of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - assert len(input_tensor) == 1 - self.model.set_input_tensor(input_tensor[0]) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, - **column_kwargs, - ) - - def setup_embeddings_and_output_layer(self) -> None: - """Sets up embedding layer in first stage and output layer in last stage. - - This function initializes word embeddings in the final stage when we are - using pipeline parallelism and sharing word embeddings, and sets up param - attributes on the embedding and output layers. - """ - # Set `is_embedding_or_output_parameter` attribute. - if self.pre_process: - self.model.embed_tokens.weight.is_embedding_or_output_parameter = True - if self.post_process and self.lm_head.weight is not None: - self.lm_head.weight.is_embedding_or_output_parameter = True - - if not self.share_embeddings_and_output_weights: - return - - if parallel_state.get_pipeline_model_parallel_world_size() == 1: - # Zero out wgrad if sharing embeddings between two layers on same - # pipeline stage to make sure grad accumulation into main_grad is - # correct and does not include garbage values (e.g., from torch.empty). 
- self.shared_embedding_or_output_weight().zero_out_wgrad = True - return - - if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: - self.shared_embedding_or_output_weight().shared_embedding = True - - if self.post_process and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.lm_head.weight.data.fill_(0) - self.lm_head.weight.shared = True - self.lm_head.weight.shared_embedding = True - - if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): - weight = self.shared_embedding_or_output_weight() - weight.data = weight.data.to(get_device_name()) - torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) - - def shared_embedding_or_output_weight(self) -> torch.Tensor: - if self.pre_process: - return self.model.embed_tokens.weight - elif self.post_process: - return self.lm_head.weight - return None - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = ' - # f'{self.config.vocab_size}') # [4, 32, 4096] - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits = self.lm_head(hidden_states, weight=output_weight)[0] - # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - return logits - - def forward( - self, - # original input - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the 
masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. - # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model - batch_size, sequence_length = input_ids.shape - # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) - - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - if self.post_process: - hidden_states = outputs - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. 
If input is already rmpad, we let the caller pad_input - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - else: - return outputs - - -class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - if self.post_process: - output.logits = torch.squeeze(output.logits, dim=-1) - return output - else: - return output diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py deleted file mode 100644 index 0904f14fad4..00000000000 --- a/verl/models/weight_loader_registry.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2024 Bytedance Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def get_weight_loader(arch: str): - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - - _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { - "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, - "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, - } - - if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: - return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] - raise ValueError( - f"Model architectures {arch} loader are not supported for now. 
Supported architectures: " - f"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" - ) - - -def get_weight_saver(arch: str): - from verl.models.mcore.saver import ( - merge_megatron_ckpt_gptmodel, - merge_megatron_ckpt_gptmodel_dpskv3, - merge_megatron_ckpt_gptmodel_mixtral, - merge_megatron_ckpt_gptmodel_qwen2_5_vl, - merge_megatron_ckpt_gptmodel_qwen_moe, - ) - - _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { - "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, - "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, - "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, - "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, - "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, - "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, - "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, - "Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel, - "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, - } - if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: - return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] - raise ValueError( - f"Model architectures {arch} saver are not supported for now. 
Supported architectures: " - f"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" - ) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 6c906f2ac68..48ae14e781e 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -52,7 +52,6 @@ actor_rollout_ref: recompute_num_layers: null attention_backend: flash override_mcore_model_config: {} - use_mbridge: false vanilla_mbridge: true use_remove_padding: true forward_only: false @@ -187,7 +186,6 @@ actor_rollout_ref: override_ddp_config: {} override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} override_mcore_model_config: {} - use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} forward_only: true @@ -433,7 +431,6 @@ critic: recompute_num_layers: null attention_backend: flash override_mcore_model_config: {} - use_mbridge: false vanilla_mbridge: true use_remove_padding: true forward_only: false @@ -566,7 +563,6 @@ reward_model: dist_checkpointing_prefix: '' seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} - use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} dtype: bfloat16 diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index 84601f5a3f5..4b3520da645 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -74,8 
+74,6 @@ override_transformer_config: override_mcore_model_config: {} -# oc.select: default val for ref.megatron.use_mbridge -use_mbridge: False # oc.select: default val for ref.megatron.vanilla_mbridge vanilla_mbridge: True diff --git a/verl/trainer/config/ref/megatron_ref.yaml b/verl/trainer/config/ref/megatron_ref.yaml index ca1fbb3c073..07c590c72de 100644 --- a/verl/trainer/config/ref/megatron_ref.yaml +++ b/verl/trainer/config/ref/megatron_ref.yaml @@ -15,7 +15,6 @@ strategy: megatron megatron: seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} - use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} diff --git a/verl/trainer/config/reward_model/megatron_reward_model.yaml b/verl/trainer/config/reward_model/megatron_reward_model.yaml index ea585075e57..65f742de52b 100644 --- a/verl/trainer/config/reward_model/megatron_reward_model.yaml +++ b/verl/trainer/config/reward_model/megatron_reward_model.yaml @@ -61,8 +61,6 @@ megatron: # Any overrides to transformer config override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} - # Whether to use mbridge for faster comms - use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} # Whether to use mbridge instead of Megatron-Bridge vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index 26a5f7b23ab..602e834986a 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ 
b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -27,7 +27,6 @@ from megatron.core.transformer.enums import AttnBackend from transformers import GenerationConfig -from verl.models.weight_loader_registry import get_weight_saver from verl.utils.device import get_device_name, get_torch_device from verl.utils.fs import is_non_local, local_mkdir_safe from verl.utils.logger import log_with_rank @@ -154,10 +153,6 @@ def __init__( ) self.use_hf_checkpoint = not self.use_dist_checkpointing - self.weight_saver = None - if self.bridge is None: - self.weight_saver = get_weight_saver(self.arch) - def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False): """collect rng state across data parallel ranks""" rng_state = { @@ -568,65 +563,13 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if self.should_save_hf_model and not self.use_hf_checkpoint: # wait for everyone to dump to local - if self.bridge is not None: - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - if self.vanilla_bridge: - self.bridge.save_weights( - self.model, hf_model_ckpt_path, distributed_filesystem=True, memory_efficient=True - ) - else: - self.bridge.save_hf_weights(self.model, hf_model_ckpt_path) - else: - state_dict = self.weight_saver( - self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights, + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + if self.vanilla_bridge: + self.bridge.save_weights( + self.model, hf_model_ckpt_path, distributed_filesystem=True, memory_efficient=True ) - - torch.distributed.barrier() - if self.rank == 0: - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - import warnings - - from accelerate import init_empty_weights - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if "mistral7b-rm" in self.config.model.path: - 
from transformers import MistralForSequenceClassification - - model = MistralForSequenceClassification.from_pretrained( - self.config.model.path - ) # use score head instead of lm_head - state_dict["score.weight"] = state_dict["score.weight"] - else: - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") - model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - - if hdfs_path is not None: - log_with_rank( - f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True - ) - from verl.utils import hdfs_io - - hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) - log_with_rank( - f"HDFS checkpoint uploaded to {hdfs_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) + else: + self.bridge.save_hf_weights(self.model, hf_model_ckpt_path) def finalize_save_fn(): # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided diff --git a/verl/utils/megatron_peft_utils.py b/verl/utils/megatron_peft_utils.py index a18776c3413..3f8bee5ae09 100644 --- a/verl/utils/megatron_peft_utils.py +++ b/verl/utils/megatron_peft_utils.py @@ -59,7 +59,7 @@ def get_adapter_state_dict(model): Returns: Dict of adapter parameter names to tensors """ - from verl.utils.megatron_utils import unwrap_model + from megatron.core.utils import unwrap_model # Unwrap model from DDP/Float16Module unwrapped = unwrap_model(model) @@ -138,8 +138,7 @@ def load_adapter_checkpoint( strict: Whether to strictly enforce parameter name matching """ from megatron.core import mpu - - from verl.utils.megatron_utils import unwrap_model + from megatron.core.utils import unwrap_model # Get rank-specific path rank_path = _get_rank_checkpoint_path(checkpoint_path) @@ 
-192,7 +191,7 @@ def count_adapter_parameters(model): Returns: Tuple of (adapter_params, total_params, percentage) """ - from verl.utils.megatron_utils import unwrap_model + from megatron.core.utils import unwrap_model unwrapped = unwrap_model(model) if isinstance(unwrapped, list): diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index c319910d855..95284330e53 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -19,31 +19,22 @@ import gc import inspect import os -import warnings from dataclasses import dataclass from typing import Any import torch -import torch.nn.functional as F -from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel +from megatron.core import mpu, tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.enums import ModelType from megatron.core.optimizer import ChainedOptimizer from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import Float16Module -from megatron.core.utils import get_attr_wrapped_model +from megatron.core.utils import get_model_config from transformers import PretrainedConfig -import verl.utils.megatron.tensor_parallel as tp_utils from verl.utils.device import get_device_id, get_device_name, get_torch_device from verl.utils.fs import local_mkdir_safe -from verl.utils.model import normalize_model_name -from verl.utils.torch_dtypes import PrecisionType - - -def get_model_config(model): - return get_attr_wrapped_model(model, "config", allow_none=False) def get_model( @@ -184,221 +175,95 @@ def make_megatron_module( if override_model_config is None: override_model_config = {} - if bridge is not None: - if provider is None: - from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - - value_model_hook = make_value_model - else: - from verl.models.mcore.bridge import 
freeze_moe_router, make_value_model - - hidden_size = ( - hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size - ) - value_model_hook = make_value_model(hidden_size, provider.sequence_parallel) - - post_model_creation_callbacks = [] - if wrap_config.is_value_model: - post_model_creation_callbacks.append(value_model_hook) - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - if provider is not None: - # When using PEFT with Megatron-Bridge, we must apply PEFT transformation - # BEFORE wrapping the model in DDP. This is required because: - # 1. PEFT freezes base model parameters (requires_grad=False) - # 2. DDP must be aware of which parameters are trainable when building gradient buckets - # 3. The distributed optimizer must only track trainable (adapter) parameters - # See Megatron-Bridge docs: training/peft.md - - # Register PEFT transformation as pre-wrap hook if peft_cls is specified - # This must happen BEFORE DDP wrapping to avoid KeyError with frozen parameters - if peft_cls is not None: - from verl.utils.megatron_peft_utils import load_adapter_checkpoint, print_adapter_info - - def peft_pre_wrap_hook(model): - """Pre-wrap hook that applies PEFT transformation.""" - # Apply PEFT transformation - this will freeze base model and add adapters - # The PEFT callable handles both freezing and transformation - transformed_model = peft_cls(model, training=True) - - # Set parameters to save (adapter-only checkpointing) - peft_cls.set_params_to_save(transformed_model) - - # Load adapter weights if adapter_path is specified - adapter_path = getattr(peft_config, "adapter_path", None) - if adapter_path is not None and adapter_path: - print(f"Loading adapter weights from: {adapter_path}") - load_adapter_checkpoint(transformed_model, adapter_path) - - # Print PEFT statistics - if torch.distributed.get_rank() == 0: - 
print_adapter_info(transformed_model) - - return transformed_model - - provider.register_pre_wrap_hook(peft_pre_wrap_hook) - - # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks - for callback in post_model_creation_callbacks: - provider.register_pre_wrap_hook(callback) - - # Create DDP config if needed - ddp_config = None - if wrap_config.wrap_with_ddp: - from megatron.bridge.training.config import DistributedDataParallelConfig - - ddp_config_dict = { - "use_distributed_optimizer": wrap_config.use_distributed_optimizer, - } - # Apply any DDP config overrides - if override_ddp_config is not None: - ddp_config_dict.update(override_ddp_config) - - ddp_config = DistributedDataParallelConfig(**ddp_config_dict) - ddp_config.finalize() - - # Now call provide_distributed_model with all hooks registered - # Hooks will be applied automatically before DDP wrapping - model = provider.provide_distributed_model( - wrap_with_ddp=wrap_config.wrap_with_ddp, - ddp_config=ddp_config, - ) + if provider is None: + from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - # Extract TransformerConfig from the created model - tf_config = get_model_config(model[0] if isinstance(model, list) else model) - else: - model = bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, - wrap_with_ddp=wrap_config.wrap_with_ddp, - fp16=tf_config.fp16, - bf16=tf_config.bf16, - ddp_config=override_ddp_config, - ) + value_model_hook = make_value_model else: - - def megatron_model_provider(pre_process, post_process, vp_stage=None): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - tf_config, - hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=wrap_config.share_embeddings_and_output_weights, - value=wrap_config.is_value_model, - freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), - vp_stage=vp_stage, - ) - 
parallel_model.to(get_device_name()) - return parallel_model - - model = get_model( - megatron_model_provider, + from verl.models.mcore.bridge import freeze_moe_router, make_value_model + + hidden_size = hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size + value_model_hook = make_value_model(hidden_size, provider.sequence_parallel) + + post_model_creation_callbacks = [] + if wrap_config.is_value_model: + post_model_creation_callbacks.append(value_model_hook) + if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + post_model_creation_callbacks.append(freeze_moe_router) + if provider is not None: + # When using PEFT with Megatron-Bridge, we must apply PEFT transformation + # BEFORE wrapping the model in DDP. This is required because: + # 1. PEFT freezes base model parameters (requires_grad=False) + # 2. DDP must be aware of which parameters are trainable when building gradient buckets + # 3. The distributed optimizer must only track trainable (adapter) parameters + # See Megatron-Bridge docs: training/peft.md + + # Register PEFT transformation as pre-wrap hook if peft_cls is specified + # This must happen BEFORE DDP wrapping to avoid KeyError with frozen parameters + if peft_cls is not None: + from verl.utils.megatron_peft_utils import load_adapter_checkpoint, print_adapter_info + + def peft_pre_wrap_hook(model): + """Pre-wrap hook that applies PEFT transformation.""" + # Apply PEFT transformation - this will freeze base model and add adapters + # The PEFT callable handles both freezing and transformation + transformed_model = peft_cls(model, training=True) + + # Set parameters to save (adapter-only checkpointing) + peft_cls.set_params_to_save(transformed_model) + + # Load adapter weights if adapter_path is specified + adapter_path = getattr(peft_config, "adapter_path", None) + if adapter_path is not None and adapter_path: + print(f"Loading adapter weights from: {adapter_path}") + 
load_adapter_checkpoint(transformed_model, adapter_path) + + # Print PEFT statistics + if torch.distributed.get_rank() == 0: + print_adapter_info(transformed_model) + + return transformed_model + + provider.register_pre_wrap_hook(peft_pre_wrap_hook) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + for callback in post_model_creation_callbacks: + provider.register_pre_wrap_hook(callback) + + # Create DDP config if needed + ddp_config = None + if wrap_config.wrap_with_ddp: + from megatron.bridge.training.config import DistributedDataParallelConfig + + ddp_config_dict = { + "use_distributed_optimizer": wrap_config.use_distributed_optimizer, + } + # Apply any DDP config overrides + if override_ddp_config is not None: + ddp_config_dict.update(override_ddp_config) + + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) + ddp_config.finalize() + + # Now call provide_distributed_model with all hooks registered + # Hooks will be applied automatically before DDP wrapping + model = provider.provide_distributed_model( wrap_with_ddp=wrap_config.wrap_with_ddp, - use_distributed_optimizer=wrap_config.use_distributed_optimizer, - override_ddp_config=override_ddp_config, + ddp_config=ddp_config, ) - return model, tf_config + # Extract TransformerConfig from the created model + tf_config = get_model_config(model[0] if isinstance(model, list) else model) + else: + model = bridge.get_model( + post_model_creation_callbacks=post_model_creation_callbacks, + wrap_with_ddp=wrap_config.wrap_with_ddp, + fp16=tf_config.fp16, + bf16=tf_config.bf16, + ddp_config=override_ddp_config, + ) -ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) - - -def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): - return_list = True - if not isinstance(model, list): - model = [model] - return_list = False - unwrapped_model = [] - for model_module in model: - while isinstance(model_module, module_instances): - model_module = 
model_module.module - unwrapped_model.append(model_module) - if not return_list: - return unwrapped_model[0] - return unwrapped_model - - -def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig: - """[Deprecated] convert config - - Args: - hf_config (PretrainedConfig): _description_ - megatron_config (_type_): _description_ - - Returns: - TransformerConfig: _description_ - """ - - warnings.warn("[deprecated] use config converter for more model support", stacklevel=2) - print(f"megatron config {megatron_config}") - dt = PrecisionType.to_dtype(megatron_config.params_dtype) - print(f"pipeline_dtype=megatron_config {dt}") - qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) - overlap_p2p_comm = ( - mpu.get_virtual_pipeline_model_parallel_world_size() is not None - and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 - ) - batch_p2p_comm = False - transformer_config = TransformerConfig( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - ffn_hidden_size=hf_config.intermediate_size, - # max_position_embeddings=hf_config.max_position_embeddings, - activation_func=F.silu, - normalization="RMSNorm", - # rotary_percent=False, # default, - gated_linear_unit=True, # for llama - use_cpu_initialization=True, - apply_residual_connection_post_layernorm=False, # check what's this mean - add_bias_linear=False, - tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), - pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), - virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), - context_parallel_size=mpu.get_context_parallel_world_size(), - overlap_p2p_comm=overlap_p2p_comm, - batch_p2p_comm=batch_p2p_comm, - pipeline_dtype=dt, - params_dtype=dt, - 
sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1, - variable_seq_lengths=True, - masked_softmax_fusion=True, - moe_token_dispatcher_type="alltoall", - attention_dropout=hf_config.attention_dropout, - hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0), - add_qkv_bias=qkv_bias, - bf16=dt is torch.bfloat16, - ) - - return transformer_config - - -def mcore_model_parallel_config( - sequence_parallel: bool, - params_dtype: torch.dtype, -) -> ModelParallelConfig: - # WARNING: Code should not reach this point. This function is deprecated and will be removed. - # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead. - warnings.warn( - "Code should not reach this point. This function is deprecated and will be removed. Please use " - "hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", - DeprecationWarning, - stacklevel=2, - ) - return ModelParallelConfig( - tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), - pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), - virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), - context_parallel_size=mpu.get_context_parallel_world_size(), - sequence_parallel=sequence_parallel, - params_dtype=params_dtype, - pipeline_dtype=params_dtype, - bf16=True, - fp16=False, - timers=None, - ) + return model, tf_config @torch.no_grad() @@ -609,581 +474,6 @@ def get_transformer_config_checkpoint_path(checkpoint_path): return os.path.join(checkpoint_path, "transformer_config.json") -def convert_megatron_model_to_transformers_model( - name, - param, - config: PretrainedConfig, - tp_size: int, - num_query_groups: int, - convert_qkv_gate_up_by_trunk_concat=False, -): - """Convert megatron model to transformers model.""" - new_params = {} - - def convert_qkv_shard(full_tensor, q_name, k_name, v_name): - nonlocal config - nonlocal tp_size - nonlocal num_query_groups - - q_shard_list = [] - 
k_shard_list = [] - v_shard_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_shard_list.append(q_part) - k_shard_list.append(k_part) - v_shard_list.append(v_part) - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_shard_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_shard_list.append(k_part) - v_shard_list.append(v_part) - - new_params[q_name] = torch.cat(q_shard_list, dim=0) - new_params[k_name] = torch.cat(k_shard_list, dim=0) - new_params[v_name] = torch.cat(v_shard_list, dim=0) - - def 
convert_gate_up_shard(full_tensor, gate_name, up_name): - nonlocal config - nonlocal tp_size - - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - new_params[gate_name] = torch.cat(gate_weight_list, dim=0) - new_params[up_name] = torch.cat(up_weight_list, dim=0) - - if name == "embedding.word_embeddings.weight": - new_params["model.embed_tokens.weight"] = param - elif "self_attention" in name: - splitted_name = name.split(".") - layer_number = splitted_name[2] - component = splitted_name[4] - param_type = splitted_name[5] - if component == "linear_proj": - new_params[f"model.layers.{layer_number}.self_attn.o_proj.weight"] = param - elif component == "linear_qkv" and not isinstance(param, list): - if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = param - else: - if convert_qkv_gate_up_by_trunk_concat: - convert_qkv_shard( - param, - f"model.layers.{layer_number}.self_attn.q_proj.{param_type}", - f"model.layers.{layer_number}.self_attn.k_proj.{param_type}", - f"model.layers.{layer_number}.self_attn.v_proj.{param_type}", - ) - else: - new_params[f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}"] = param - elif component == "q_layernorm" or component == "k_layernorm": - hf_component = component.replace("layer", "") - new_params[f"model.layers.{layer_number}.self_attn.{hf_component}.weight"] = param - else: - assert isinstance(param, list) and len(param) == 3 - assert param_type == "weight" or param_type == "bias" - new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = param[0] - 
new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = param[1] - new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = param[2] - elif "mlp" in name: - splitted_name = name.split(".") - layer_number = splitted_name[2] - component = splitted_name[4] - param_type = splitted_name[5] - if component == "linear_fc1" and not isinstance(param, list): - if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.post_attention_layernorm.weight"] = param - elif param_type == "weight": - if convert_qkv_gate_up_by_trunk_concat: - convert_gate_up_shard( - param, - f"model.layers.{layer_number}.mlp.gate_proj.weight", - f"model.layers.{layer_number}.mlp.up_proj.weight", - ) - else: - new_params[f"model.layers.{layer_number}.mlp.gate_up_proj.weight"] = param - elif component == "linear_fc1" and isinstance(param, list): - assert len(param) == 2 - assert param_type == "weight" or param_type == "bias" - new_params[f"model.layers.{layer_number}.mlp.gate_proj.weight"] = param[0] - new_params[f"model.layers.{layer_number}.mlp.up_proj.weight"] = param[1] - elif component == "linear_fc2": - new_params[f"model.layers.{layer_number}.mlp.down_proj.weight"] = param - elif name == "decoder.final_layernorm.weight": - new_params["model.norm.weight"] = param - elif name == "output_layer.weight": - new_params["lm_head.weight"] = param - else: - raise ValueError(f"Unknown param name: {name}") - return new_params.keys(), new_params.values() - - -def broadcast_from_megatron_pp(tensor: torch.Tensor): - # tensor is not None only in one of the pp ranks - if tensor is not None: - shape = tensor.shape - dtype = tensor.dtype - tensor_parallel = getattr(tensor, "tensor_model_parallel", None) - partition_dim = getattr(tensor, "partition_dim", None) - tensor_spec = (shape, dtype, tensor_parallel, partition_dim) - else: - tensor_spec = None - tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - 
torch.distributed.all_gather_object( - object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() - ) - # find the src rank - target_tensor_spec = None - src_rank = None - for rank, tensor_spec in enumerate(tensor_spec_output): - if tensor_spec is not None: - if target_tensor_spec is None: - target_tensor_spec = tensor_spec - else: - raise ValueError("A tensor exists on two pp ranks") - src_rank = rank - assert target_tensor_spec is not None - if tensor is None: - tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id()) - if target_tensor_spec[2] is not None: - tensor.tensor_model_parallel = target_tensor_spec[2] - if target_tensor_spec[3] is not None: - tensor.partition_dim = target_tensor_spec[3] - - global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) - torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group()) - return tensor - - -def broadcast_str_from_megatron_pp(obj: Any): - obj_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group()) - - src_rank = None - target_obj = None - for rank, item in enumerate(obj_output): - if item is not None: - if target_obj is not None: - raise ValueError("An object exists on two pp ranks") - target_obj = item - src_rank = rank - - assert target_obj is not None, "No valid object found to broadcast." 
- - global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) - - obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) - obj_output[0] = target_obj - torch.distributed.broadcast_object_list( - object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() - ) - - return obj_output[0] - - -def default_tp_concat_fn( - layer_name_mapping, - name, - train_params, - infer_params, - model_config, - hf_config=None, - convert_qkv_gate_up_by_simple_split=False, -): - """ - name: name of the parameter - train_params: training parameters - infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group - model_config: huggingface model_config - TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model - definition so that it is model-agnostic. If the model doesn't implement this function, - we can throw an error to force user disable TP HybridEngine. - """ - from megatron.core import mpu - - train_tp_size = mpu.get_tensor_model_parallel_world_size() - if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. 
- q_lst = [] - k_lst = [] - v_lst = [] - num_attention_heads = model_config.num_attention_heads - num_key_value_heads = model_config.num_key_value_heads - if "vision_model" in name: - num_attention_heads = hf_config.vision_config.num_heads - num_key_value_heads = hf_config.vision_config.num_heads - assert num_attention_heads % num_key_value_heads == 0 - num_q_per_kv = num_attention_heads // num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( - f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" - ) - kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - for infer_param in infer_params: - num_query_groups_per_partition = num_key_value_heads // train_tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] - - elif ( - layer_name_mapping.get("gate_proj_layer_name") in name - and "layer_norm" not in name - and "vision_model.projection" not in name - ): - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in infer_params: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] - - elif "mlp.experts.linear_fc2.weight" in name: # moe - infer_params = torch.cat(infer_params, dim=1) - - else: - # concat 
tensor - infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params)) - - return infer_params - - -def per_tensor_generator( - actor_module, - model_config, - weight_converter, - transformer_config, - layer_name_mapping, - convert_qkv_gate_up_by_simple_split=True, -): - from megatron.core import parallel_state as mpu - - pp_rank = mpu.get_pipeline_model_parallel_rank() - ep_size = mpu.get_expert_model_parallel_world_size() - etp_size = mpu.get_expert_tensor_parallel_world_size() - ep_group = mpu.get_expert_model_parallel_group() - etp_group = mpu.get_expert_tensor_parallel_group() - vpp_size = len(actor_module) - all_gather_group = mpu.get_tensor_model_parallel_group() - all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) - - def tensor_generator(): - for scan_vpp_idx in range(vpp_size): - existing_keys = set() - model = unwrap_model(actor_module[scan_vpp_idx]) - for name, param in model.named_parameters(): - existing_keys.add(name) - yield name, param - # note - # there is a bug in megatron GPTModel - # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in - # state_dict(). for now we patch it by adding those keys to extra_keys. 
- extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] - for name in extra_keys: - yield name, model.state_dict()[name].to(get_device_id()) - - # we need first make all rank get full model information - meta_info = [] - for scan_vpp_idx in range(vpp_size): - existing_keys = set() - model = unwrap_model(actor_module[scan_vpp_idx]) - for idx, (name, _) in enumerate(model.named_parameters()): - existing_keys.add(name) - meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] - for name in extra_keys: - meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - - obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object( - object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() - ) - layer_list_meta = [item for sublist in obj_spec_output for item in sublist] - - gen_func = tensor_generator() - - # lazy load tensor for full model - for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: - if model_config.tie_word_embeddings and ("output_layers" in name): - import warnings - - warnings.warn( - "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2 - ) - continue - - if cur_pp_rank == pp_rank: - try: - cur_name, cur_tensor = next(gen_func) - except StopIteration: - cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config) - else: - cur_tensor, cur_name = None, None - - # pp broadcast model tensor and name - cur_name = broadcast_str_from_megatron_pp(cur_name) - broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor) - - # (xya): this is a hack to fix the name of the parameters - while cur_name.startswith("module."): - cur_name = cur_name[len("module.") :] - - # EP - if ".mlp.experts.linear_fc" in cur_name and ep_size > 1: - 
num_experts = weight_converter.mcore_config.num_moe_experts - num_experts_per_rank = num_experts // ep_size - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group) - - name_prefix, local_expert_id = cur_name.split(".weight") - local_expert_id = int(local_expert_id) - global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)] - global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] - - for name, param in zip(global_expert_names, infer_params, strict=True): - if etp_size > 1: - # gather etp - etp_params = [torch.empty_like(param) for _ in range(etp_size)] - torch.distributed.all_gather(etp_params, param, group=etp_group) - params = etp_params - else: - params = [param] - - merge_params = default_tp_concat_fn( - layer_name_mapping, - name, - broad_pp_tensor, - params, - model_config, - weight_converter.hf_config, - convert_qkv_gate_up_by_simple_split, - ) - if not isinstance(merge_params, list): - merge_params = [merge_params] - converted_names, converted_params = weight_converter.convert_param(name, merge_params) - - yield from zip(converted_names, [param.detach() for param in converted_params], strict=True) - continue - - # tp all gather - if tp_utils.is_tensor_parallel_param(broad_pp_tensor): - # allocate a new tensor with proper size - if all_gather_group_size <= 1: - infer_params = [broad_pp_tensor] - else: - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) - infer_params = default_tp_concat_fn( - layer_name_mapping, - cur_name, - broad_pp_tensor, - infer_params, - model_config, - weight_converter.hf_config, - convert_qkv_gate_up_by_simple_split, - ) - else: - infer_params = broad_pp_tensor - - if not isinstance(infer_params, list): - 
infer_params = [infer_params] - converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) - - yield from zip(converted_names, [param.detach() for param in converted_params], strict=True) - - -def get_transformer_layer_offset(pipeline_rank, vp_stage, config: TransformerConfig): - """ - Get the index offset of any pipeline stage, given the level of pipelining. - - Make pipeline_rank and vp_stage as two arguments to make it more flexible, - which is able to fetch layer offset for any pipeline stage. - The original function only returns the layer offset for current pipeline stage. - - Extension to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset - """ - - has_vp_stage = ( - inspect.signature(parallel_state.is_pipeline_first_stage).parameters.get("vp_stage", None) is not None - ) - extra_kwargs = {} if not has_vp_stage else {"ignore_virtual": False, "vp_stage": vp_stage} - - if config.pipeline_model_parallel_size > 1: - if hasattr(config, "pipeline_model_parallel_layout") and config.pipeline_model_parallel_layout: - from megatron.core.transformer.enums import LayerType - - offset = config.pipeline_model_parallel_layout.get_layer_offset( - layer_type=LayerType.decoder, vp_stage=vp_stage - ) - elif ( - config.num_layers_in_first_pipeline_stage is not None - or config.num_layers_in_last_pipeline_stage is not None - ): - # Calculate number of pipeline stages to distribute the remaining Transformer - # layers after deducting the Transformer layers in the first or the last stages - middle_pipeline_stages = config.pipeline_model_parallel_size - middle_pipeline_stages -= sum( - [ - 1 if x is not None else 0 - for x in ( - config.num_layers_in_first_pipeline_stage, - config.num_layers_in_last_pipeline_stage, - ) - ] - ) - - # Calculate layers to distribute in each pipeline stage. 
If the - # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage - # are not set, we will not enable uneven pipeline. All layers will be treated - # as middle layers. - num_layers_in_first_pipeline_stage = ( - 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage - ) - num_layers_in_last_pipeline_stage = ( - 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage - ) - - middle_num_layers = ( - config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage - ) - - if (vp_size := config.virtual_pipeline_model_parallel_size) is not None: - assert vp_stage is not None, "vp_stage must be provided if virtual pipeline model parallel size is set" - - # Calculate number of layers in each virtual model chunk - # If the num_layers_in_first_pipeline_stage and - # num_layers_in_last_pipeline_stage are not set, all pipeline stages - # will be treated as middle pipeline stages in the calculation - num_layers_per_virtual_model_chunk_in_first_pipeline_stage = ( - 0 - if config.num_layers_in_first_pipeline_stage is None - else config.num_layers_in_first_pipeline_stage // vp_size - ) - - num_layers_per_virtual_model_chunk_in_last_pipeline_stage = ( - 0 - if config.num_layers_in_last_pipeline_stage is None - else config.num_layers_in_last_pipeline_stage // vp_size - ) - - num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size - - # First stage + middle stage + last stage - total_virtual_chunks = ( - num_layers_per_virtual_model_chunk_in_first_pipeline_stage - + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage - + num_layers_per_virtual_model_chunk_in_last_pipeline_stage - ) - - # Calculate the layer offset with interleaved uneven pipeline parallelism - if pipeline_rank == 0: - offset = vp_stage * total_virtual_chunks - else: - offset = ( - vp_stage * total_virtual_chunks - + 
num_layers_per_virtual_model_chunk_in_first_pipeline_stage - + (pipeline_rank - 1) - * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) - ) - else: - if middle_pipeline_stages > 0: - num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages - else: - num_layers_per_pipeline_rank = 0 - - middle_pipeline_rank = ( - pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 - ) - - if pipeline_rank == 0: - offset = 0 - else: - offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage - else: - num_layers = config.num_layers - - # Increase the number of layers by one if we include the embedding (loss) - # layer into pipeline parallelism partition and placement - if config.account_for_embedding_in_pipeline_split: - num_layers += 1 - - if config.account_for_loss_in_pipeline_split: - num_layers += 1 - - num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size - - if (vp_size := config.virtual_pipeline_model_parallel_size) is not None: - assert vp_stage is not None, "vp_stage must be provided if virtual pipeline model parallel size is set" - - num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size - total_virtual_chunks = num_layers // vp_size - offset = vp_stage * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) - - # Reduce the offset of embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage( - **extra_kwargs - ): - offset -= 1 - else: - offset = pipeline_rank * num_layers_per_pipeline_rank - - # Reduce the offset of embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage( - **extra_kwargs - ): - offset -= 1 - else: - offset = 0 - return offset - - def register_megatron_training_hooks(model: list[torch.nn.Module], 
optimizer): from megatron.core.distributed import finalize_model_grads from megatron.core.utils import get_model_config @@ -1220,3 +510,11 @@ def register_megatron_training_hooks(model: list[torch.nn.Module], optimizer): config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: config.param_sync_func = config.param_sync_func[0] + + +def mapping_string_to_attn_backend(args: dict) -> dict: + if "attention_backend" in args and isinstance(args["attention_backend"], str): + from megatron.core.transformer.enums import AttnBackend + + args["attention_backend"] = AttnBackend[args["attention_backend"]] + return args diff --git a/verl/utils/model.py b/verl/utils/model.py index b9d17cfbb0b..1d12b0053da 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -18,7 +18,6 @@ import json import os import re -import warnings from dataclasses import dataclass from typing import Optional @@ -36,13 +35,10 @@ AutoModelForTokenClassification, AutoModelForVision2Seq, GenerationConfig, - MistralForSequenceClassification, - PretrainedConfig, PreTrainedModel, ) from transformers.modeling_outputs import CausalLMOutputWithPast -from verl.models.registry import ModelRegistry from verl.utils.import_utils import is_trl_available @@ -323,136 +319,6 @@ def check_target_modules(config, key: str) -> bool: return target_module_found -def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"): - """ - Transform the model name in each model_chunk in each pp stage into the name in inference engine - """ - from verl.utils.megatron_utils import get_transformer_layer_offset - - layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config) - - if layer_name in name: # belong to an intermediate layer - split_name = name.split(".") - # find the num next to split_name - for i, name in enumerate(split_name): - if name == layer_name: - break - layer_num_idx = i + 1 - # check the name - assert len(split_name) >= 
layer_num_idx + 1, f"split_name = {split_name}" - assert split_name[layer_num_idx].isdigit(), f"split_name = {split_name}" - # increment layer_num_idx by layer_offset - split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset) - name = ".".join(split_name) # weight name in inference_tp_model - return name - - -def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): - """ - Normalize the pp vpp params into a complete named parameters. - This is useful when gather parameters from pp ranks and passed to a model without pp - - params: Iterable[List[Dict[str, param]]] - params contains a list of pp, with a list of vpp named_parameters in each vpp chunk. - output: Dict[str, param] - - """ - pp_size = len(params) - for pp_rank in range(len(params)): - vpp_size = len(params[pp_rank]) - for vpp_rank in range(vpp_size): - for name, param in params[pp_rank][vpp_rank].items(): - normalized_name = normalize_model_name( - name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name - ) - yield normalized_name, param - - -def get_parallel_model_from_config( - config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False -): - from megatron.core import ModelParallelConfig - - assert isinstance(megatron_config, ModelParallelConfig) - model_class = _get_parallel_model_architecture_from_config(config, value) - - model = model_class( - config, - megatron_config, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - ) - return model - - -def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch, value) - print("after load model cls") - if model_cls is not None: - return model_cls - raise ValueError( - f"Model 
architectures {architectures} are not supported for now. Supported architectures: " - f"{ModelRegistry.get_supported_archs()}" - ) - - -def _load_hf_model(config, model_config, is_value_model): - """Helper function containing the loading hf model logic""" - from accelerate import init_empty_weights - from megatron.core import parallel_state as mpu - - from verl.models.mcore.saver import _megatron_calc_global_rank - - assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" - architectures = getattr(model_config, "architectures", []) - - # get auto class - auto_cls = get_hf_auto_model_class(model_config) - - if config.model.path.startswith("hdfs:"): - from verl.utils.fs import copy_to_local - - print(f"start download from {config.model.path}") - local_model_path = copy_to_local(src=config.model.path, use_shm=config.model.get("use_shm", False)) - print("finish download") - else: - local_model_path = config.model.path - print(f"load from local dir {local_model_path}") - - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank()) - cpu_init_weights = lambda: torch.device("cpu") - init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - # TODO: to find a better way to load mistral7b-rm lm_head - if "mistral7b-rm" in config.model.path: - model = MistralForSequenceClassification.from_pretrained( - local_model_path, - torch_dtype="auto", - # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank - # low_cpu_mem_usage=True - ) # use score head instead of lm_head - state_dict = model.state_dict() - state_dict["lm_head.weight"] = state_dict["score.weight"] - state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ - :32000 - ] # workaround, 32001 -> 32000 - is_value_model = True - else: - 
model = auto_cls.from_pretrained( - local_model_path, - torch_dtype="auto", - # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank - # low_cpu_mem_usage=True - ) - state_dict = model.state_dict() - - return architectures, model, state_dict, is_value_model - - def get_hf_model_path(config): if config.model.path.startswith("hdfs:"): from verl.utils.fs import copy_to_local @@ -463,43 +329,6 @@ def get_hf_model_path(config): return local_model_path -def load_megatron_model_weights(config, model_config, parallel_model, params_dtype, is_value_model=False): - """Load weights for verl customized model.""" - architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model) - - from verl.models.weight_loader_registry import get_weight_loader - - print(f"before weight loader: architectures = {architectures}...") - for arch in architectures: - print(f"call weight loader arch = {arch}, model config = {model.config}") - weight_loader = get_weight_loader(arch) - weight_loader( - state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model, - tie_word_embeddings=model_config.tie_word_embeddings, - ) - return model.config - - -def load_megatron_gptmodel_weights(config, model_config, parallel_model, params_dtype, is_value_model=False): - """Load weights for mcore GPT model.""" - _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model) - - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - - load_state_dict_to_megatron_gptmodel( - state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model, - ) - del state_dict, model - - # pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size): 
"""pad the tokens such that the total length is a multiple of size. @@ -537,8 +366,7 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False, prefix=""): from megatron.core import dist_checkpointing from megatron.core.dist_checkpointing.serialization import StrictHandling - - from verl.utils.megatron_utils import unwrap_model + from megatron.core.utils import unwrap_model # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED strict = StrictHandling.ASSUME_OK_UNEXPECTED @@ -553,42 +381,6 @@ def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=Fal return -def get_parallel_gptmodel_from_config( - tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False -): - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec - from megatron.core.models.gpt.gpt_model import GPTModel - - use_te = True - assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) - rope_scaling_args = {} - if hf_config.rope_scaling is not None: - assert hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] - parallel_model = GPTModel( - config=tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=hf_config.vocab_size, - max_sequence_length=hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type="rope", - rotary_base=hf_config.rope_theta, - **rope_scaling_args, - ) - # # for layer in parallel_model.decoder.layers: - # 
layer.self_attention.core_attention.flash_attention.softmax_scale = None - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - parallel_model.output_layer = LinearForLastLayer( - input_size=tfconfig.hidden_size, output_size=1, config=tfconfig - ) - return parallel_model - - def patch_valuehead_model(model) -> None: from types import MethodType diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 466402cab4e..17bba726cbf 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -33,6 +33,7 @@ # from megatron.core.optimizer import DistributedOptimizer from megatron.core.optimizer import DistributedOptimizer from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.utils import get_model_config, unwrap_model from omegaconf import OmegaConf from torch import nn @@ -49,7 +50,6 @@ set_router_replay_data, ) from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits -from verl.utils.megatron_utils import get_model_config, unwrap_model from verl.utils.profiler import GPUMemoryLogger from verl.utils.profiler.profile import Profiler from verl.utils.py_functional import append_to_dict diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index a8799c35691..0c7417a6989 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -99,7 +99,6 @@ class McoreEngineConfig(EngineConfig): seed (int): Random seed for reproducibility. override_ddp_config (dict[str, Any]): Override configuration for DDP. override_transformer_config (dict[str, Any]): Override configuration for transformer. - use_mbridge (bool): Whether to use MBridge for communication. 
dtype (str): Mixed precision training param dtype, default "bfloat16" """ @@ -120,7 +119,6 @@ class McoreEngineConfig(EngineConfig): override_ddp_config: dict[str, Any] = field(default_factory=dict) override_transformer_config: dict[str, Any] = field(default_factory=dict) override_mcore_model_config: dict[str, Any] = field(default_factory=dict) - use_mbridge: bool = False vanilla_mbridge: bool = True strategy: str = "megatron" diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 00292a73c2e..d5d2b680a9b 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -41,13 +41,11 @@ load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, - per_tensor_generator, register_megatron_training_hooks, ) from verl.utils.model import ( extract_multi_modal_inputs_tensordict, load_mcore_dist_weights, - load_megatron_gptmodel_weights, ) from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig @@ -110,70 +108,62 @@ def _init_device_mesh(self): ) def _build_tf_config(self): - from verl.models.mcore import hf_to_mcore_config - from verl.models.mcore.config_converter import mapping_string_to_attn_backend + from verl.utils.megatron_utils import mapping_string_to_attn_backend from verl.utils.torch_dtypes import PrecisionType self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype) - if self.param_dtype == torch.float16: - assert self.engine_config.use_mbridge, "fp16 mode requires use_mbridge to be True" self.dtype = PrecisionType.to_dtype(self.param_dtype) override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config}) - use_mbridge = self.engine_config.use_mbridge self.provider = None self.vanilla_bridge = self.engine_config.vanilla_mbridge - if use_mbridge: - if self.vanilla_bridge: - from verl.models.mcore.mbridge import AutoBridge - - bridge = 
AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype) - bridge.set_extra_args(**override_transformer_config) - tf_config = bridge.config - tf_config.fp16 = self.param_dtype == torch.float16 - tf_config.bf16 = self.param_dtype == torch.bfloat16 - else: - from verl.models.mcore.bridge import AutoBridge - - # Use Megatron-Bridge to convert HF config to Megatron config - bridge = AutoBridge.from_hf_pretrained( - self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code - ) - # Get Megatron provider and configure it - provider = bridge.to_megatron_provider(load_weights=False) - - # In case of invalid overrides, we need to make sure some critical params are set correctly - provider.params_dtype = self.param_dtype - - # Pass distributed info - provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size - provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size - provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size - provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size - provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size - provider.context_parallel_size = self.engine_config.context_parallel_size - provider.sequence_parallel = self.engine_config.sequence_parallel - - # Match verl implementation (need variable_seq_lengths) - from megatron.core.transformer.enums import AttnBackend - - provider.attention_backend = AttnBackend.flash - provider.variable_seq_lengths = True - provider.moe_token_dispatcher_type = "alltoall" - provider.moe_router_load_balancing_type = "none" - - # Apply transformer config overrides - for key, value in override_transformer_config.items(): - setattr(provider, key, value) - - provider.finalize() - self.provider = provider - tf_config = None # Will be set after model creation - self.bridge = bridge + if self.vanilla_bridge: + from 
verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + tf_config.fp16 = self.param_dtype == torch.float16 + tf_config.bf16 = self.param_dtype == torch.bfloat16 else: - self.bridge = None - tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config) + from verl.models.mcore.bridge import AutoBridge + + # Use Megatron-Bridge to convert HF config to Megatron config + bridge = AutoBridge.from_hf_pretrained( + self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code + ) + # Get Megatron provider and configure it + provider = bridge.to_megatron_provider(load_weights=False) + + # In case of invalid overrides, we need to make sure some critical params are set correctly + provider.params_dtype = self.param_dtype + + # Pass distributed info + provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size + provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size + provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size + provider.context_parallel_size = self.engine_config.context_parallel_size + provider.sequence_parallel = self.engine_config.sequence_parallel + + # Match verl implementation (need variable_seq_lengths) + from megatron.core.transformer.enums import AttnBackend + + provider.attention_backend = AttnBackend.flash + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + + # Apply transformer config overrides + for key, value in override_transformer_config.items(): + 
setattr(provider, key, value) + + provider.finalize() + self.provider = provider + tf_config = None # Will be set after model creation + self.bridge = bridge if not self.bridge: self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype) @@ -232,28 +222,14 @@ def _build_megatron_module(self): if self.engine_config.use_dist_checkpointing: load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model) else: - if self.bridge is not None: - if self.vanilla_bridge: - self.bridge.load_weights(module, self.model_config.local_path) - else: - allowed_mismatched_params = [] - if self.is_value_model: - allowed_mismatched_params = ["output_layer.weight"] - self.bridge.load_hf_weights( - module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params - ) + if self.vanilla_bridge: + self.bridge.load_weights(module, self.model_config.local_path) else: - # (vermouth1992) this is a workaround to be compatible with the old API - tmp_config = OmegaConf.create( - {"model": {"path": self.model_config.local_path, "use_shm": self.model_config.use_shm}} - ) - - load_megatron_gptmodel_weights( - tmp_config, - self.model_config.hf_config, - module, - params_dtype=self.dtype, - is_value_model=is_value_model, + allowed_mismatched_params = [] + if self.is_value_model: + allowed_mismatched_params = ["output_layer.weight"] + self.bridge.load_hf_weights( + module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params ) if torch.distributed.get_rank() == 0: @@ -562,16 +538,7 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw def get_per_tensor_param(self): if self._is_offload_param: load_megatron_model_to_gpu(self.module, load_grad=False) - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.module) - else: - per_tensor_param = per_tensor_generator( - self.module, - self.model_config.hf_config, - 
self.weight_converter, - self.tf_config, - self.layer_name_mapping, - ) + per_tensor_param = self.bridge.export_weights(self.module) # TODO: support megatron LoRA return per_tensor_param, None diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index db2e3fb1b97..1ed4cf95f84 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -35,7 +35,6 @@ from megatron.core import parallel_state as mpu from verl import DataProto -from verl.models.mcore import get_mcore_weight_converter from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register from verl.utils import hf_tokenizer @@ -57,11 +56,10 @@ load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, - per_tensor_generator, register_megatron_training_hooks, ) from verl.utils.memory_utils import aggressive_empty_cache -from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.model import get_hf_model_path, load_mcore_dist_weights from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -115,7 +113,6 @@ def _init_hf_config_and_tf_config( ): from transformers import AutoConfig - from verl.models.mcore import hf_to_mcore_config from verl.utils import hf_processor, hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.model import update_model_config @@ -154,65 +151,58 @@ def _init_hf_config_and_tf_config( if self.rank == 0: print(f"Model config after override: {hf_config}") - from verl.models.mcore.config_converter import mapping_string_to_attn_backend + from verl.utils.megatron_utils import mapping_string_to_attn_backend - # todo: remove this line after mcore adopt mbridge 0.15, now for compatibility override_transformer_config = mapping_string_to_attn_backend(override_transformer_config) fp16 = dtype == torch.float16 bf16 = dtype == torch.bfloat16 - 
if fp16: - assert megatron_config.use_mbridge, "fp16 mode requires use_mbridge to be True" self.provider = None self.vanilla_bridge = megatron_config.get("vanilla_mbridge", True) - if megatron_config.use_mbridge: - if self.vanilla_bridge: - from verl.models.mcore.mbridge import AutoBridge - - bridge = AutoBridge.from_config(hf_config, dtype=dtype) - bridge.set_extra_args(**override_transformer_config) - tf_config = bridge.config - tf_config.fp16 = fp16 - tf_config.bf16 = bf16 - else: - from verl.models.mcore.bridge import AutoBridge - - # Use Megatron-Bridge to convert HF config to Megatron config - bridge = AutoBridge.from_hf_pretrained(self.local_path, trust_remote_code=trust_remote_code) - # Get Megatron provider and configure it - provider = bridge.to_megatron_provider(load_weights=False) - - # In case of invalid overrides, we need to make sure some critical params are set correctly - provider.params_dtype = dtype - - # Pass distributed info - provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size - provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size - provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size - provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size - provider.virtual_pipeline_model_parallel_size = megatron_config.virtual_pipeline_model_parallel_size - provider.context_parallel_size = megatron_config.context_parallel_size - provider.sequence_parallel = megatron_config.sequence_parallel - - # Match verl implementation (need variable_seq_lengths) - from megatron.core.transformer.enums import AttnBackend - - provider.attention_backend = AttnBackend.flash - provider.variable_seq_lengths = True - provider.moe_token_dispatcher_type = "alltoall" - provider.moe_router_load_balancing_type = "none" - - # Apply transformer config overrides - for key, value in override_transformer_config.items(): - setattr(provider, key, value) - - 
provider.finalize() - self.provider = provider - tf_config = None # Will be set after model creation - self.bridge = bridge + if self.vanilla_bridge: + from verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(hf_config, dtype=dtype) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + tf_config.fp16 = fp16 + tf_config.bf16 = bf16 else: - tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) - self.bridge = None + from verl.models.mcore.bridge import AutoBridge + + # Use Megatron-Bridge to convert HF config to Megatron config + bridge = AutoBridge.from_hf_pretrained(self.local_path, trust_remote_code=trust_remote_code) + # Get Megatron provider and configure it + provider = bridge.to_megatron_provider(load_weights=False) + + # In case of invalid overrides, we need to make sure some critical params are set correctly + provider.params_dtype = dtype + + # Pass distributed info + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size + provider.virtual_pipeline_model_parallel_size = megatron_config.virtual_pipeline_model_parallel_size + provider.context_parallel_size = megatron_config.context_parallel_size + provider.sequence_parallel = megatron_config.sequence_parallel + + # Match verl implementation (need variable_seq_lengths) + from megatron.core.transformer.enums import AttnBackend + + provider.attention_backend = AttnBackend.flash + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + + # Apply transformer config overrides + for key, value in override_transformer_config.items(): + setattr(provider, key, value) + + 
provider.finalize() + self.provider = provider + tf_config = None # Will be set after model creation + self.bridge = bridge if torch.distributed.get_rank() == 0: if tf_config is not None: @@ -407,16 +397,11 @@ def _build_model_optimizer( prefix=self.config.actor.megatron.dist_checkpointing_prefix, ) else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - if self.vanilla_bridge: - self.bridge.load_weights(actor_module, local_model_path) - else: - self.bridge.load_hf_weights(actor_module, local_model_path) + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(actor_module, local_model_path) else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False - ) + self.bridge.load_hf_weights(actor_module, local_model_path) if self.rank == 0: print_model_size(actor_module[0]) @@ -448,16 +433,11 @@ def _build_model_optimizer( prefix=self.config.ref.megatron.dist_checkpointing_prefix, ) else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - if self.vanilla_bridge: - self.bridge.load_weights(ref_module, local_model_path) - else: - self.bridge.load_hf_weights(ref_module, local_model_path) + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(ref_module, local_model_path) else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False - ) + self.bridge.load_hf_weights(ref_module, local_model_path) log_gpu_memory_usage("After ref module init", logger=logger) return ref_module, self.hf_config @@ -658,8 +638,6 @@ def init_model(self): "gate_proj_layer_name": "linear_fc1.", } self.weight_converter = None - if not self.config.actor.megatron.use_mbridge: - self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) get_torch_device().empty_cache() 
log_gpu_memory_usage("After init_model finish", logger=logger) @@ -673,19 +651,10 @@ async def rollout_mode(self): load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False) log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger) - if self.bridge is not None: - if self.vanilla_bridge: - per_tensor_param = self.bridge.export_weights(self.actor.actor_module) - else: - per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module) + if self.vanilla_bridge: + per_tensor_param = self.bridge.export_weights(self.actor.actor_module) else: - per_tensor_param = per_tensor_generator( - self.actor.actor_module, - self.actor_model_config, - self.weight_converter, - self.tf_config, - self.layer_name_mapping, - ) + per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module) if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) @@ -1102,17 +1071,12 @@ def _build_critic_model_optimizer( prefix=self.config.megatron.dist_checkpointing_prefix, ) else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - if self.vanilla_bridge: - self.bridge.load_weights(critic_module, local_model_path) - else: - self.bridge.load_hf_weights( - critic_module, local_model_path, allowed_mismatched_params=["output_layer.weight"] - ) + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(critic_module, local_model_path) else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True + self.bridge.load_hf_weights( + critic_module, local_model_path, allowed_mismatched_params=["output_layer.weight"] ) t1 = time.time() if torch.distributed.get_rank() == 0: @@ -1381,19 +1345,13 @@ def _build_rm_model(self, model_path, tokenizer, override_model_config, override prefix=self.config.megatron.dist_checkpointing_prefix, ) else: - if self.bridge is not None: - local_model_path = 
get_hf_model_path(self.config) - if self.vanilla_bridge: - self.bridge.load_weights(reward_model, local_model_path) - else: - self.bridge.load_hf_weights( - reward_model, local_model_path, allowed_mismatched_params=["output_layer.weight"] - ) + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(reward_model, local_model_path) else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True + self.bridge.load_hf_weights( + reward_model, local_model_path, allowed_mismatched_params=["output_layer.weight"] ) - get_torch_device().empty_cache() return reward_model, self.hf_config