diff --git a/.github/workflows/utils_cpu_test.yml b/.github/workflows/cpu_unit_tests.yml similarity index 84% rename from .github/workflows/utils_cpu_test.yml rename to .github/workflows/cpu_unit_tests.yml index e3ec220d078..62015cb3d0e 100644 --- a/.github/workflows/utils_cpu_test.yml +++ b/.github/workflows/cpu_unit_tests.yml @@ -1,4 +1,4 @@ -name: utils_cpu_test +name: cpu_unit_tests on: # Trigger the workflow on push or pull request, @@ -13,7 +13,7 @@ on: - v0.* paths: - "**/*.py" - - .github/workflows/utils_cpu_test.yml + - .github/workflows/cpu_unit_tests.yml - "!recipe/**/*.py" # Cancel jobs on the same ref if a new one is triggered @@ -26,7 +26,7 @@ permissions: contents: read jobs: - utils_cpu_test: + cpu_unit_tests: runs-on: ubuntu-latest timeout-minutes: 10 # Increase this timeout value as needed strategy: @@ -41,7 +41,7 @@ jobs: - name: Install the current repository run: | pip install -e .[test] - - name: Running test protocol.py + - name: Running data proto test run: | cd tests pytest -s -x test_protocol.py @@ -53,3 +53,7 @@ jobs: run: | cd tests/trainer pytest -s -x . + - name: Running worker tests + run: | + cd tests/workers/reward_manager + pytest -s -x . 
diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index 78a5d8cbce9..6dbf77ffc10 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -61,7 +61,7 @@ jobs: e2e_ppo_trainer_vllm: runs-on: [L20x8] - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -161,6 +161,14 @@ jobs: run: | ray stop --force LIGER=True bash tests/e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNELS=True bash tests/e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Triton Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton bash tests/e2e/ppo_trainer/run_model_reward.sh e2e_ppo_trainer_vllm_vlm: runs-on: [L20x8] @@ -181,13 +189,13 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install -e .[test,geo,vllm] + pip3 install -e .[test,gpu,vllm,geo,trl] # Geo3k - name: Prepare Geo3k dataset run: | ray stop --force python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E training tests on 8 L20 GPUs with rmpad using function rm + - name: Running Geo3k VLM GRPO E2E training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ @@ -197,6 +205,16 @@ jobs: SP_SIZE=2 \ bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Running Geo3k VLM PPO E2E training tests on 8 L20 GPUs with rmpad using function rm + run: | + ray stop --force + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \ + ADV_ESTIMATOR=gae RM_PAD=True 
USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + SP_SIZE=2 \ + bash tests/e2e/ppo_trainer/run_function_reward.sh + e2e_ppo_trainer_sglang: runs-on: [L20x8] needs: pre_commit_for_ppo @@ -262,7 +280,7 @@ jobs: e2e_ppo_trainer_sglang_vlm: runs-on: [L20x8] needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -294,74 +312,24 @@ jobs: ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ bash tests/e2e/ppo_trainer/run_function_reward.sh - - e2e_ppo_trainer_fused_kernels_vllm: - runs-on: [L20x8] - needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0 - options: --gpus all --shm-size=50g # Visual dataloader requires large memory - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test,geo,vllm] - # Geo3k - - name: Prepare Geo3k dataset - run: | - ray stop --force - python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + - name: Running Geo3k VLM E2E with rmpad using torch fused kernel (Qwen2.5-VL) run: | ray stop --force FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ - 
GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ bash tests/e2e/ppo_trainer/run_function_reward.sh - - e2e_ppo_trainer_fused_kernels_sglang: - runs-on: [L20x8] - needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=50g # Visual dataloader requires large memory - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test,geo,gpu,sglang] - - name: Prepare Geo3k dataset - run: | - ray stop --force - python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + - name: Running Geo3k VLM E2E with rmpad using triton fused kernel (Qwen2.5-VL) run: | ray stop --force - FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \ + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ - bash tests/e2e/ppo_trainer/run_function_reward.sh \ No newline at end of file + bash tests/e2e/ppo_trainer/run_function_reward.sh diff --git 
a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml index cda8edf3ee2..4ead8d2ef25 100644 --- a/.github/workflows/kernels.yml +++ b/.github/workflows/kernels.yml @@ -38,7 +38,7 @@ permissions: contents: read jobs: - e2e_gsm8k_megatron: + kernels: runs-on: [L20x8] timeout-minutes: 40 # Increase this timeout value as needed env: @@ -59,4 +59,7 @@ jobs: pip3 install --no-deps -e .[test] - name: Testing LinearCrossEntropy Correction, Computation Time and Memory Consumption run: | - python3 tests/kernels/test_linear_cross_entropy.py \ No newline at end of file + python3 tests/kernels/test_linear_cross_entropy.py + - name: Testing LinearCrossEntropyTP Correction, Computation Time and Memory Consumption + run: | + torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/kernels/test_linear_cross_entropy_tp.py \ No newline at end of file diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst index cd308c44d09..e7a9dc7c2be 100644 --- a/docs/api/trainer.rst +++ b/docs/api/trainer.rst @@ -12,17 +12,18 @@ Trainers drive the training loop. Introducing new trainer classes in case of new Core APIs ~~~~~~~~~~~~~~~~~ -.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer +.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer :members: __init__, init_workers, fit - .. automodule:: verl.utils.tokenizer :members: hf_tokenizer - .. automodule:: verl.trainer.ppo.core_algos :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty - .. automodule:: verl.trainer.ppo.reward :members: load_reward_manager, compute_reward, compute_reward_async + +.. autoclass:: verl.workers.reward_manager.NaiveRewardManager + +.. 
autoclass:: verl.workers.reward_manager.DAPORewardManager diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh new file mode 100644 index 00000000000..498ca61b4ad --- /dev/null +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh @@ -0,0 +1,64 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +FUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=4096 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ + critic.optim.lr=1e-5 \ + 
critic.model.use_remove_padding=True \ + critic.model.path=Qwen/Qwen2-7B-Instruct \ + critic.model.enable_gradient_checkpointing=True \ + critic.use_dynamic_bsz=True \ + critic.ppo_max_token_len_per_gpu=98304 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + reward_model.enable=True \ + reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ + reward_model.model.use_remove_padding=True \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=32 \ + reward_model.use_dynamic_bsz=True \ + reward_model.forward_max_token_len_per_gpu=98304 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=False \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/recipe/char_count/READMD.md b/recipe/char_count/README.md similarity index 100% rename from recipe/char_count/READMD.md rename to recipe/char_count/README.md diff --git a/recipe/char_count/train_grpo.sh b/recipe/char_count/train_grpo.sh index 566b1d6f9b0..3f008d050eb 100644 --- a/recipe/char_count/train_grpo.sh +++ b/recipe/char_count/train_grpo.sh @@ -40,5 +40,5 @@ python3 -m verl.trainer.main_ppo \ trainer.save_freq=-1 \ trainer.test_freq=5 \ trainer.total_epochs=2 \ - custom_reward_function.path=/home/chi/Developer/verl/recipe/char_count/reward_function.py \ + custom_reward_function.path=recipe/char_count/reward_function.py \ custom_reward_function.name=char_count_reward_function \ No newline at end of file diff --git a/recipe/char_count/train_sft.sh b/recipe/char_count/train_sft.sh index 202c86f0482..a9bf1d5babf 100644 --- a/recipe/char_count/train_sft.sh +++ b/recipe/char_count/train_sft.sh @@ -13,7 +13,7 @@ torchrun 
--standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ data.max_length=256 \ data.train_batch_size=256 \ use_remove_padding=True \ - model.partial_pretrain=$HOME/models/SmolLM2-135M-Instruct \ + model.partial_pretrain=HuggingFaceTB/SmolLM2-135M-Instruct \ trainer.default_local_dir=$save_path \ trainer.project_name=char_count-sft \ trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \ diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index 19b3a0d2c72..cb3d3365be3 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -155,7 +155,6 @@ def fit(self): new_batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index 6a29dfb7699..d2d7b6190f9 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -118,21 +118,12 @@ def run(self, config): role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager + from verl.workers.reward_manager import get_reward_manager_cls - reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError + # Note(haibin.lin): please make sure custom reward managers are imported and + # registered via `verl.workers.reward_manager.register` + reward_manager_name = config.reward_model.get("reward_manager", "naive") + 
reward_manager_cls = get_reward_manager_cls(reward_manager_name) compute_score = get_custom_reward_fn(config) reward_fn = reward_manager_cls( diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 56989bf932f..23a3c440369 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -33,6 +33,8 @@ reward_model: ref_path: ${reward_model.model.path} use_remove_padding: True use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fused_kernel_options: + impl_backend: torch # triton, torch tokenizer_path: ${actor_rollout_ref.model.path} enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} ref_type: freeze diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index cb603b7a3ed..83979333a5f 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -77,6 +77,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): attention_mask=None, position_ids=position_ids_rmpad, use_cache=False, + return_dict=self.use_fused_kernels, ) if self.use_fused_kernels: @@ -100,6 +101,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): attention_mask=micro_batch["attention_mask"], position_ids=micro_batch["position_ids"], use_cache=False, + return_dict=self.use_fused_kernels, ) if self.use_fused_kernels: diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 5b9cf4b8f25..1b14cfc741e 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -129,11 +129,15 @@ def _build_reward_ref_model_optimizer(self, config): trust_remote_code=trust_remote_code, ) + fused_kernel_options = config.model.get("fused_kernel_options", None) + fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + apply_monkey_patch( model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size, 
use_remove_padding=config.model.get("use_remove_padding", False), use_fused_kernels=config.model.get("use_fused_kernels", False), + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype diff --git a/recipe/spin/main_spin.py b/recipe/spin/main_spin.py index 679b78866be..17c54d95059 100644 --- a/recipe/spin/main_spin.py +++ b/recipe/spin/main_spin.py @@ -113,25 +113,12 @@ def run(self, config): role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "batch": - from verl.workers.reward_manager import BatchRewardManager - - reward_manager_cls = BatchRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager + from verl.workers.reward_manager import get_reward_manager_cls - reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError + # Note(haibin.lin): please make sure custom reward managers are imported and + # registered via `verl.workers.reward_manager.register` + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) compute_score = get_custom_reward_fn(config) reward_kwargs = dict(config.reward_model.get("reward_kwargs", {})) diff --git a/recipe/spin/spin_trainer.py b/recipe/spin/spin_trainer.py index a7484399499..6e0edf1d87d 100644 --- a/recipe/spin/spin_trainer.py +++ b/recipe/spin/spin_trainer.py @@ -395,15 +395,6 @@ def __init__(self, if config.algorithm.use_kl_in_reward: self.kl_ctrl_in_reward = 
core_algos.get_kl_controller(config.algorithm.kl_ctrl) - # if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - # self.use_critic = True - # elif self.config.algorithm.adv_estimator in [ - # AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX, - # AdvantageEstimator.RLOO, AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE - # ]: - # self.use_critic = False - # else: - # raise NotImplementedError self.use_critic = False self._validate_config() self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py index fac7bbbd1ac..3917550b187 100644 --- a/recipe/sppo/sppo_ray_trainer.py +++ b/recipe/sppo/sppo_ray_trainer.py @@ -266,7 +266,6 @@ def fit(self): reward_tensor, reward_extra_infos_dict = ray.get(future_reward) batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) diff --git a/setup.py b/setup.py index e6d3dc5b4cc..14305e67300 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ "torch-memory-saver>=0.0.5", "torch==2.6.0", ] +TRL_REQUIRES = ["trl<=0.9.6"] extras_require = { "test": TEST_REQUIRES, @@ -64,6 +65,7 @@ "math": MATH_REQUIRES, "vllm": VLLM_REQUIRES, "sglang": SGLANG_REQUIRES, + "trl": TRL_REQUIRES, } diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index efb6620df5d..6d07870744f 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -20,6 +20,7 @@ ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} RM_PAD=${RM_PAD:-True} FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend 
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} USE_KL=${USE_KL:-False} CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False} @@ -90,6 +91,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.strategy=${STRATEGY} \ diff --git a/tests/e2e/ppo_trainer/run_model_reward.sh b/tests/e2e/ppo_trainer/run_model_reward.sh index 4c11e7a27cc..5e401ad0087 100644 --- a/tests/e2e/ppo_trainer/run_model_reward.sh +++ b/tests/e2e/ppo_trainer/run_model_reward.sh @@ -11,6 +11,8 @@ TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} RM_PAD=${RM_PAD:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend SP_SIZE=${SP_SIZE:-1} SEQ_BALANCE=${SEQ_BALANCE:-False} LIGER=${LIGER:-False} @@ -47,6 +49,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_liger="${LIGER}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ + actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \ diff --git a/tests/kernels/test_linear_cross_entropy.py b/tests/kernels/test_linear_cross_entropy.py index f0fd0e1a63d..8ad28936e97 100644 --- a/tests/kernels/test_linear_cross_entropy.py +++ b/tests/kernels/test_linear_cross_entropy.py @@ -29,23 +29,28 @@ # See the 
License for the specific language governing permissions and # limitations under the License. +import os import typing import torch import verl.utils.torch_functional as verl_F from verl.utils.experimental.torch_functional import FusedLinearForPPO +from verl.utils.kernel import linear_cross_entropy from verl.utils.torch_functional import logprobs_from_logits compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) fused_linear_for_ppo = FusedLinearForPPO() fused_linear_for_ppo.compile(dynamic=True) +MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5)) -def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction="none") -> typing.List[torch.Tensor]: + +def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none") -> typing.List[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] @@ -55,10 +60,16 @@ def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch. 
return logprobs, entropy -def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]: +def run_verl_original_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +) -> typing.List[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature # compute entropy entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) @@ -67,23 +78,27 @@ def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels # To be tested -def run_verl_torch_fused_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor): +def run_verl_torch_fused_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +): hidden = hidden.to(torch.float32) weight = weight.to(torch.float32) logprobs, entropy = fused_linear_for_ppo( hidden, weight, labels, + temperature=temperature, ) return logprobs.squeeze(0), entropy.squeeze(0) -MAX_TEST_CASES = 5 - - class TestLinearCrossEntropy: - def __init__(self, test_case_idx: int) -> None: + def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None: self.test_case_idx = test_case_idx + self.temperature = temperature def cleanup(self): torch.cuda.empty_cache() @@ -121,7 +136,8 @@ def generate_hyper(self): self.hidden_size = 4096 self.vocab_size = 102400 else: - raise ValueError(f"Invalid test case index: {test_case_idx}") + raise ValueError(f"Invalid test case index: {self.test_case_idx}") + assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." 
def generate_forward_inputs(self): hidden = torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() @@ -144,6 +160,8 @@ def verify_correctness(self, iterations=5): verl_backward_latency = list() verl_fused_forward_latency = list() verl_fused_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -153,30 +171,44 @@ def verify_correctness(self, iterations=5): hidden, weight, labels = self.generate_forward_inputs() start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels) + (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() verl_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(hidden, weight, labels) + (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() verl_fused_forward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_entropy, 
atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + # backward g_entropy, g_logprobs = self.generate_backward_inputs() @@ -198,12 +230,28 @@ def verify_correctness(self, iterations=5): torch.cuda.synchronize() verl_fused_backward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) + 
torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) # remove first latency torch_forward_latency = torch_forward_latency[1:] @@ -212,6 +260,8 @@ def verify_correctness(self, iterations=5): verl_backward_latency = verl_backward_latency[1:] verl_fused_forward_latency = verl_fused_forward_latency[1:] verl_fused_backward_latency = verl_fused_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] print("\n[INFO]: Verified forward & backward correctness.") @@ -221,6 +271,8 @@ def verify_correctness(self, iterations=5): print(f"[INFO]: Backward pass: VeRL implementation average time: {sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms") print(f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms") print(f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: {sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: {sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") def check_storage(self, method_name, run_forward): self.cleanup() @@ -229,7 +281,7 @@ def check_storage(self, method_name, run_forward): 
hidden, weight, labels = self.generate_forward_inputs() torch.cuda.reset_peak_memory_stats() - (logprobs, entropy) = run_forward(hidden, weight, labels) + (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature) torch.cuda.synchronize() torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") @@ -246,6 +298,7 @@ def check_storage_all(self): self.check_storage("Torch", run_torch_entropy) self.check_storage("VeRL", run_verl_original_entropy) self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy) + self.check_storage("Kernel", linear_cross_entropy) if __name__ == "__main__": diff --git a/tests/kernels/test_linear_cross_entropy_tp.py b/tests/kernels/test_linear_cross_entropy_tp.py new file mode 100644 index 00000000000..35f6f971f1e --- /dev/null +++ b/tests/kernels/test_linear_cross_entropy_tp.py @@ -0,0 +1,442 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import typing
+
+import torch
+import torch.distributed as dist
+
+try:
+    from verl.utils.kernel import linear_cross_entropy
+except ImportError:
+    # FIXME: remove these manually included paths
+    import sys
+
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
+finally:
+    from verl.utils.kernel import linear_cross_entropy
+
+import verl.utils.torch_functional as verl_F
+
+compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+
+MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
+VERIFY_TORCH_SELF = os.environ.get("VERIFY_TORCH_SELF", "False").lower() in ("1", "true", "yes")
+
+
+def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none") -> typing.List[torch.Tensor]:
+    # [num_tokens, vocab_size]
+    if len(hidden.shape) > 2:
+        hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]
+    if len(labels.shape) > 1:
+        labels = labels.view(-1)
+    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32))
+    logits /= temperature
+    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
+    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
+    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
+    entropy = entropy_a - entropy_b
+    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]
+    logprobs = torch.neg(logprobs)
+    return logprobs, entropy
+
+
+class 
TorchEntropyTP(torch.autograd.Function): + """ + it is used for testing the correctness of the kernel + it is not efficient and is not recommended to use in practice + """ + + @staticmethod + def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, dist_process_group: torch.distributed.ProcessGroup): + # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size] + ctx.original_hidden_shape = hidden.shape + if len(hidden.shape) > 2: + hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] + if len(labels.shape) > 1: + labels = labels.view(-1) + + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size] + logits /= temperature + whole_logits = torch.empty((logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), dtype=logits.dtype, device=logits.device) + whole_logits_ref = [whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]] for i in range(dist.get_world_size(dist_process_group))] + dist.all_gather(whole_logits_ref, logits, group=dist_process_group) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + + logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.neg(logprobs) + + ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) + ctx.dist_process_group = dist_process_group + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): + hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors + dist_process_group = ctx.dist_process_group + temperature = ctx.temperature + batch_size, hidden_size = hidden.shape + vocab_size, hidden_size = weight.shape + rank = 
dist.get_rank(dist_process_group) + + # Compute softmax probabilities + maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) + exp_logits = torch.exp(whole_logits - maximum) + accumulate = exp_logits.sum(dim=-1, keepdim=True) + pd = exp_logits / accumulate + + # Gradient for entropy + # entropy = entropy_a - entropy_b + # entropy_a = log(sum(exp(logits))) + # entropy_b = sum(pd * logits) + # d_entropy_a/d_logits = pd + # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = d_entropy_a - d_entropy_b + # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) + d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + + # Gradient for logprobs + # logprobs = -cross_entropy = -log(pd[labels]) + # d_logprobs/d_logits = (pd - one_hot(labels)) + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + g_logprobs = torch.neg(g_logprobs) + d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) + # NOTE: This will lead to wrong result + # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot + + # Combine gradients + d_logits = d_logits_entropy + d_logits_logprobs + d_logits /= temperature + + # Get local slice of gradients + local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size] + + # Compute gradients for hidden and weight + d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32)) + d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32)) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return d_hidden, d_weight, None, None, None + + +run_torch_entropy_tp = TorchEntropyTP.apply + + +class TestLinearCrossEntropy_TensorParallel: + def __init__(self): + dist.init_process_group(backend="nccl") + self.group = dist.group.WORLD + + self.local_rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + device = 
torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") + + def initialize(self, test_case_idx: int, temperature: float = 1.5): + self.test_case_idx = test_case_idx + self.temperature = temperature + + def shutdown(self): + dist.destroy_process_group() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.dtype = torch.bfloat16 + if self.test_case_idx == 0: + self.batch_size = 1 + self.num_tokens = 1937 + self.hidden_size = 3584 + self.vocab_size = 152064 + elif self.test_case_idx == 1: + self.batch_size = 1 + self.num_tokens = 2169 + self.hidden_size = 896 + self.vocab_size = 151936 + elif self.test_case_idx == 2: + self.batch_size = 1 + self.num_tokens = 1530 + self.hidden_size = 2048 + self.vocab_size = 32256 + elif self.test_case_idx == 3: + self.batch_size = 1 + self.num_tokens = 1388 + self.hidden_size = 4096 + self.vocab_size = 102400 + elif self.test_case_idx == 4: + self.batch_size = 1 + self.num_tokens = 8192 + self.hidden_size = 4096 + self.vocab_size = 102400 + else: + raise ValueError(f"Invalid test case index: {self.test_case_idx}") + assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." 
+ + def generate_forward_inputs(self): + hidden = torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() + weight = torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() + labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) + g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) + return g_entropy, g_logprobs + + def verify_torch_itself(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + # forward pass + # Create a tensor to hold the gathered weights from all ranks + # weight has shape [vocab_size, hidden_size] + # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size] + + # Create a single contiguous tensor to hold all gathered weights + whole_weight = torch.empty((self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device) + + # Create views into the tensor for each rank's portion + whole_weight_views = [whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)] + + # Perform all_gather operation using the views + dist.all_gather(whole_weight_views, weight, group=self.group) + + # Set requires_grad for autograd + whole_weight.requires_grad_() + + (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature) + + (tp_logprobs, tp_entropy) = 
run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + + torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + (single_d_hidden, single_d_weight) = torch.autograd.grad((single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False) + + (tp_d_hidden, tp_d_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + # Extract the corresponding slice from single_d_weight for comparison + # tp_d_weight has shape [vocab_size, hidden_size] + # single_d_weight has shape [vocab_size * world_size, hidden_size] + torch.testing.assert_close(tp_d_weight, single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size], atol=1e-2, rtol=1e-4) + + # atol=1e-3, rtol=1e-4) + if self.local_rank == 0: + print("[PASS] torch TP correctness is verified") + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + torch.cuda.synchronize() + 
forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_tp_hidden, d_tp_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + torch.cuda.synchronize() + backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") + print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + + def verify_kernel_correctness(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature, "none", self.group) + end_event.record() + torch.cuda.synchronize() + 
kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + start_event.record() + (torch_d_hidden, torch_d_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + start_event.record() + (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + if self.local_rank == 0: + print("\n[PASS]: Verified kernel forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: {sum(torch_forward_latency) / 
len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: {sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: {sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: {sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature, "none", self.group) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + +if __name__ == "__main__": + # TP command: torchrun --standalone 
--nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py + + # Check if running with torchrun (distributed mode) + assert int(os.environ["WORLD_SIZE"]) > 1, "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to execute this script." + torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) + + # set_backward_method(BackwardEnum._Total_Fuse_MN) + # set_backward_method(BackwardEnum._Split_Dlogits_N) + + test = TestLinearCrossEntropy_TensorParallel() + for test_case_idx in range(MAX_TEST_CASES): + print(f"[INFO] Running test case {test_case_idx}") + test.initialize(test_case_idx) + if VERIFY_TORCH_SELF: + test.verify_torch_itself() + test.check_torch_storage() + test.verify_kernel_correctness() + test.check_kernel_storage() + + test.shutdown() diff --git a/tests/trainer/ppo/test_core_algos.py b/tests/trainer/ppo/test_core_algos.py new file mode 100644 index 00000000000..d285d15f06a --- /dev/null +++ b/tests/trainer/ppo/test_core_algos.py @@ -0,0 +1,126 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import pytest + +import verl.trainer.ppo.core_algos +from verl.trainer.ppo.core_algos import get_adv_estimator_fn, register_adv_est + + +def mock_test_fn(): + pass + +class TestRegisterAdvEst(unittest.TestCase): + def setUp(self): + """Clear the registry before each test""" + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = { + "gae": lambda x: x * 2, + "vtrace": lambda x: x + 1, + } + self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY + + def tearDown(self) -> None: + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() + return super().tearDown() + + def test_register_new_function(self): + """Test registering a new function with a string name""" + @register_adv_est("test_estimator") + def test_fn(): + pass + + self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn) + + def test_register_with_enum(self): + """Test registering with an enum value (assuming AdvantageEstimator exists)""" + from enum import Enum + class AdvantageEstimator(Enum): + TEST = "test_enum_estimator" + + @register_adv_est(AdvantageEstimator.TEST) + def test_fn(): + pass + + self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn) + + def test_duplicate_registration_same_function(self): + """Test that registering the same function twice doesn't raise an error""" + register_adv_est("duplicate_test")(mock_test_fn) + register_adv_est("duplicate_test")(mock_test_fn) + + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn) + + def test_duplicate_registration_different_function(self): + """Test that registering different functions with same name raises ValueError""" + @register_adv_est("conflict_test") + def test_fn1(): + pass + + with self.assertRaises(ValueError): + 
@register_adv_est("conflict_test") + def test_fn2(): + pass + + def test_decorator_preserves_function(self): + """Test that the decorator returns the original function""" + def test_fn(): + return "original" + + decorated = register_adv_est("preserve_test")(test_fn) + self.assertEqual(decorated(), "original") + + def test_multiple_registrations(self): + """Test registering multiple different functions""" + init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY) + @register_adv_est("estimator1") + def fn1(): + pass + + @register_adv_est("estimator2") + def fn2(): + pass + + self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2) + + def test_get_adv_estimator_fn_valid_names(self): + """Test that valid names return the correct function from registry.""" + # Test GAE + gae_fn = get_adv_estimator_fn("gae") + assert gae_fn(5) == 10 # 5 * 2 = 10 + + # Test Vtrace + vtrace_fn = get_adv_estimator_fn("vtrace") + assert vtrace_fn(5) == 6 # 5 + 1 = 6 + + def test_get_adv_estimator_fn_invalid_name(self): + """Test that invalid names raise ValueError.""" + with pytest.raises(ValueError) as excinfo: + get_adv_estimator_fn("invalid_name") + assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value) + + def test_get_adv_estimator_fn_case_sensitive(self): + """Test that name lookup is case-sensitive.""" + with pytest.raises(ValueError): + get_adv_estimator_fn("GAE") # Different case + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/workers/reward_manager/test_registry.py b/tests/workers/reward_manager/test_registry.py new file mode 100644 index 00000000000..6e458fe5737 --- /dev/null +++ b/tests/workers/reward_manager/test_registry.py @@ -0,0 +1,86 @@ +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +# Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module +from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register + + +@pytest.fixture +def setup(): + """Setup test cases with a mock registry.""" + REWARD_MANAGER_REGISTRY.clear() + REWARD_MANAGER_REGISTRY.update({ + "manager1": "Manager1Class", + "manager2": "Manager2Class" + }) + return REWARD_MANAGER_REGISTRY + +def test_get_existing_manager(setup): + """Test getting an existing reward manager class.""" + assert get_reward_manager_cls("manager1") == "Manager1Class" + assert get_reward_manager_cls("manager2") == "Manager2Class" + +def test_get_nonexistent_manager(setup): + """Test getting a non-existent reward manager raises ValueError.""" + with pytest.raises(ValueError) as excinfo: + get_reward_manager_cls("unknown_manager") + assert "Unknown reward manager: unknown_manager" in str(excinfo.value) + +def test_case_sensitivity(setup): + """Test that manager names are case-sensitive.""" + with pytest.raises(ValueError): + get_reward_manager_cls("MANAGER1") + with pytest.raises(ValueError): + get_reward_manager_cls("Manager1") + +def test_empty_registry(setup): + """Test behavior when registry is empty.""" + REWARD_MANAGER_REGISTRY.clear() + with pytest.raises(ValueError) as excinfo: + get_reward_manager_cls("any_manager") + assert "Unknown reward manager: any_manager" 
in str(excinfo.value)
+
+def test_register_new_class(setup):
+    """Test registering a new class with the decorator."""
+    @register("test_manager")
+    class TestManager:
+        pass
+
+    assert "test_manager" in REWARD_MANAGER_REGISTRY
+    assert REWARD_MANAGER_REGISTRY["test_manager"] == TestManager
+
+def test_register_different_classes_same_name(setup):
+    """Test that registering different classes with same name raises ValueError."""
+    @register("conflict_manager")
+    class Manager1:
+        pass
+
+    with pytest.raises(ValueError) as context:
+        @register("conflict_manager")
+        class Manager2:
+            pass
+
+    assert REWARD_MANAGER_REGISTRY["conflict_manager"] == Manager1
+
+def test_decorator_returns_original_class(setup):
+    """Test that the decorator returns the original class unchanged."""
+    @register("return_test")
+    class OriginalClass:
+        def method(self):
+            return 42
+
+    assert OriginalClass().method() == 42
+    assert REWARD_MANAGER_REGISTRY["return_test"] == OriginalClass
diff --git a/verl/models/transformers/dense_common.py b/verl/models/transformers/dense_common.py
new file mode 100644
index 00000000000..de0083dc20d
--- /dev/null
+++ b/verl/models/transformers/dense_common.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> CausalLMOutputWithPast: + r""" + Copy paste LLaMa's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py + + This function should be generic enough for all pure text models. 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + 
rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def forward_with_triton_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + 
rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index f44252bb7b8..220e83ef07a 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -13,8 +13,7 @@ # limitations under the License. import sys -from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import torch @@ -25,7 +24,6 @@ from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -230,84 +228,3 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - - -@dataclass -class CausalLMOutputForPPO(CausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None - - -def forward_for_ppo( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: 
Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputForPPO]: - r""" - Copy paste LLaMa's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py - - This function should be generic enough for all pure text models. - ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") - - fused_linear_for_ppo = FusedLinearForPPO() - log_probs, entropy = fused_linear_for_ppo.forward( - hidden_states=hidden_states, - 
vocab_weights=self.lm_head.weight, - input_ids=rolled_labels, - temperature=temperature, - ) - - return CausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index c06b237d9cd..2e4ccf220e6 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -25,6 +25,7 @@ from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_utils import PreTrainedModel +from verl.utils.import_utils import is_trl_available from verl.utils.ulysses import ( gather_heads_scatter_seq, gather_seq_scatter_heads, @@ -138,12 +139,64 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs): print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Choose the forward function based on the model and backend. + Args: + model (PreTrainedModel): The model to apply the monkey patch. + use_fused_kernels (bool): Whether to use fused kernels. + fused_kernels_backend (str): The backend to use for fused kernels. 
+ """ + if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: + print(f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is {use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}") + return + + forward_with_torch_backend_function = model.__class__.forward + forward_with_triton_backend_function = model.__class__.forward + if model.config.model_type == "qwen2_5_vl": + from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type == "qwen2_vl": + from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + else: + from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + + if fused_kernels_backend == "triton": + model.__class__.forward = forward_with_triton_backend_function + print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") + elif fused_kernels_backend == "torch": + model.__class__.forward = forward_with_torch_backend_function + print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") + else: + raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + + def apply_monkey_patch( model: PreTrainedModel, ulysses_sp_size: int = 1, use_remove_padding: bool = True, use_fused_kernels: bool = False, + fused_kernels_backend: str = None, ): + """ + Apply monkey patch to the models for ulysses sequence parallel and fused kernel. 
+ + In the end of this function forward function of the model is patched for fused kernel. + If the model is not supported with fused kernel, please return after patch. + """ + """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" module = sys.modules[model.__module__] @@ -151,16 +204,25 @@ def apply_monkey_patch( num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads except AttributeError: num_attention_heads, num_key_value_heads = model.config.text_config.num_attention_heads, model.config.text_config.num_key_value_heads - + assert num_attention_heads % ulysses_sp_size == 0, f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness." ) + + if is_trl_available(): + from trl import AutoModelForCausalLMWithValueHead + + def state_dict(self, *args, **kwargs): + return torch.nn.Module.state_dict(self, *args, **kwargs) + + AutoModelForCausalLMWithValueHead.state_dict = state_dict + print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ") + # TODO: VLM models only, unify monkey patch to LLM models. 
if model.config.model_type == "qwen2_5_vl": from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLFlashAttention2, - Qwen2_5_VLForConditionalGeneration, ) if use_remove_padding or ulysses_sp_size > 1: @@ -172,22 +234,16 @@ def apply_monkey_patch( if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) else: from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel - patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) - - if use_fused_kernels: - from verl.models.transformers.qwen2_5_vl import forward_for_ppo - Qwen2_5_VLForConditionalGeneration.forward = forward_for_ppo - - return + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) elif model.config.model_type == "qwen2_vl": from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLFlashAttention2, - Qwen2VLForConditionalGeneration, ) if use_remove_padding or ulysses_sp_size > 1: @@ -199,17 +255,12 @@ def apply_monkey_patch( if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) else: from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel - patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) - - if use_fused_kernels: - from verl.models.transformers.qwen2_vl import forward_for_ppo - Qwen2VLForConditionalGeneration.forward = forward_for_ppo - - return + patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) elif model.config.model_type == "kimi_vl": if use_remove_padding or ulysses_sp_size > 1: @@ -219,7 +270,7 @@ def apply_monkey_patch( module.KimiVLForConditionalGeneration._merge_with_image_features = _merge_with_image_features module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward print("Monkey patch 
FlashAttention2.forward in KimiVL") - + if use_fused_kernels: print("Not support fused kernels for KimiVL") @@ -237,10 +288,7 @@ def apply_monkey_patch( flash_attention._flash_attention_forward = _ulysses_flash_attention_forward print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") - if use_fused_kernels: - from verl.models.transformers.llama import forward_for_ppo - - model.__class__.forward = forward_for_ppo + patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) @lru_cache diff --git a/verl/models/transformers/qwen2_5_vl.py b/verl/models/transformers/qwen2_5_vl.py index ac4621ec5e4..30df487adc8 100644 --- a/verl/models/transformers/qwen2_5_vl.py +++ b/verl/models/transformers/qwen2_5_vl.py @@ -28,14 +28,13 @@ class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast): entropy: Optional[torch.FloatTensor] = None -def forward_for_ppo( +def forward_base_model( self: Qwen2_5_VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -47,19 +46,13 @@ def forward_for_ppo( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" Copy paste Qwen2_5_VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: @@ -70,9 +63,7 @@ def forward_for_ppo( n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) + raise ValueError(f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}") mask = input_ids == self.config.image_token_id mask_unsqueezed = mask.unsqueeze(-1) @@ -88,9 +79,7 @@ def forward_for_ppo( n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) + raise ValueError(f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}") mask = input_ids == self.config.video_token_id mask_unsqueezed = mask.unsqueeze(-1) @@ -138,11 +127,57 @@ def forward_for_ppo( return_dict=return_dict, cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: 
Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) hidden_states = outputs[0] if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") + raise NotImplementedError("forward_with_torch_backend has to return_dict") # Loss calculations if labels is not None: @@ -150,7 +185,7 @@ def forward_for_ppo( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -168,3 +203,79 @@ def 
forward_for_ppo( attentions=outputs.attentions, rope_deltas=rope_deltas, ) + + +def forward_with_triton_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + # NOTE: labels intentionally not forwarded; forward_base_model no longer accepts a labels kwarg + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + 
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + reduction="none", + temperature=temperature, + ) + + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index a7ae346ec26..b2b9db11165 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -298,14 +298,13 @@ class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): entropy: Optional[torch.FloatTensor] = None -def forward_for_ppo( +def forward_base_model( self: Qwen2VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -316,19 +315,13 @@ def forward_for_ppo( video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: +) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: r""" Copy paste Qwen2VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: @@ -339,15 +332,8 @@ def forward_for_ppo( n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + raise ValueError(f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}") + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -357,15 +343,8 @@ def forward_for_ppo( n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + raise ValueError(f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}") + video_mask = (input_ids == 
self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -401,10 +380,55 @@ def forward_for_ppo( cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + hidden_states = outputs[0] if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") + 
raise NotImplementedError("forward_with_torch_backend has to return_dict") # Loss calculations if labels is not None: @@ -412,7 +436,7 @@ def forward_for_ppo( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -430,3 +454,77 @@ def forward_for_ppo( attentions=outputs.attentions, rope_deltas=rope_deltas, ) + + +def forward_with_triton_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: + from verl.utils.kernel import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + # NOTE: labels intentionally not forwarded; forward_base_model no longer accepts a labels kwarg + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + 
pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + reduction="none", + temperature=temperature, + ) + + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 6ec1668db77..e913e3d2cff 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -204,6 +204,7 @@ def _setup_env_cuda_visible_devices(self): else: cuda_val = val os.environ["CUDA_VISIBLE_DEVICES"] = val + os.environ["HIP_VISIBLE_DEVICES"] = val if rocr_val: # You must take care if both HIP/CUDA and ROCR env vars are set as they have diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index fa9299dd613..dfd9f2245a7 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -133,11 +133,16 @@ actor_rollout_ref: # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) use_fused_kernels: false + + # Options for fused kernels. If use_fused_kernels is true, this will be used. 
+ fused_kernel_options: + + # Implementation backend for fused kernels. Options: "triton" or "torch". + impl_backend: torch # Whether to enable loading a remote code model trust_remote_code: false - # configs for the actor actor: # fsdp, fsdp2 or megatron. fsdp backend used here. diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 3fb88adc1f2..ec74b4e7abb 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -27,17 +27,26 @@ def main(config): run_ppo(config) +# Define a function to run the PPO-like training process def run_ppo(config) -> None: + # Check if Ray is not initialized if not ray.is_initialized(): - # this is for local ray cluster + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration ray.init( runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true"}}, num_cpus=config.ray_init.num_cpus, ) + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete runner = TaskRunner.remote() ray.get(runner.run.remote(config)) - # create a timeline trace file to analyze the performance + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis timeline_json_file = config.ray_init.get("timeline_json_file", None) if timeline_json_file: ray.timeline(filename=timeline_json_file) @@ -46,27 +55,29 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: def run(self, config): - # print initial config + # Print the initial configuration. 
`resolve=True` will evaluate symbolic values. from pprint import pprint from omegaconf import OmegaConf from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint(OmegaConf.to_container(config, resolve=True)) OmegaConf.resolve(config) - # download the checkpoint from hdfs + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on local_path = copy_to_local(config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)) - # instantiate tokenizer + # Instantiate the tokenizer and processor. from verl.utils import hf_processor, hf_tokenizer trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) # used for multimodal LLM, could be none + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) - # vllm early verify + # Version validation for vllm. if config.actor_rollout_ref.rollout.name in ["vllm"]: from verl.utils.vllm_utils import is_version_ge @@ -74,7 +85,7 @@ def run(self, config): if not is_version_ge(pkg="vllm", minver="0.7.3"): raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3") - # define worker classes + # Define worker classes based on the actor strategy. if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: assert config.critic.strategy in ["fsdp", "fsdp2"] from verl.single_controller.ray import RayWorkerGroup @@ -96,11 +107,14 @@ def run(self, config): from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + # Map roles to their corresponding remote worker classes. 
role_worker_mapping = { Role.ActorRollout: ray.remote(actor_rollout_cls), Role.Critic: ray.remote(CriticWorker), } + # Define the resource pool specification. + # Map roles to the resource pool. global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, @@ -110,12 +124,12 @@ def run(self, config): Role.Critic: global_pool_id, } - # we should adopt a multi-source reward function here + # We should adopt a multi-source reward function here: # - for rule-based rm, we directly call a reward score # - for model-based rm, we call a model # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data + # finally, we combine all the rewards together + # The reward type depends on the tag of the data if config.reward_model.enable: if config.reward_model.strategy in ["fsdp", "fsdp2"]: from verl.workers.fsdp_workers import RewardModelWorker @@ -126,20 +140,24 @@ def run(self, config): role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) mapping[Role.RewardModel] = global_pool_id - # use reference model + # Add a reference policy worker if KL loss or KL reward is used. if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id + # Load the reward manager for training and validation. reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) from verl.utils.dataset.rl_dataset import collate_fn + # Create training and validation datasets. 
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. trainer = RayPPOTrainer( config=config, tokenizer=tokenizer, @@ -155,7 +173,9 @@ def run(self, config): train_sampler=train_sampler, device_name=config.trainer.device, ) + # Initialize the workers of the trainer. trainer.init_workers() + # Start the training process. trainer.fit() @@ -163,6 +183,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor): """Create a dataset. Arguments: + data_paths: List of paths to data files. data_config: The data config. tokenizer (Tokenizer): The tokenizer. processor (Processor): The processor. @@ -174,16 +195,22 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor): from verl.utils.dataset.rl_dataset import RLHFDataset + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: from verl.utils.import_utils import load_extern_type + # Dynamically load the custom dataset class dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset if not issubclass(dataset_cls, Dataset): raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset") else: + # Use the default RLHFDataset class if no custom class is specified dataset_cls = RLHFDataset print(f"Using dataset class: {dataset_cls.__name__}") + # Instantiate the dataset using the determined dataset class dataset = dataset_cls( data_files=data_paths, tokenizer=tokenizer, @@ -207,12 +234,14 @@ def 
create_rl_sampler(data_config, dataset): import torch from torch.utils.data import RandomSampler, SequentialSampler - # use sampler for better ckpt resume + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. if data_config.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(data_config.get("seed", 1)) sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. sampler = SequentialSampler(data_source=dataset) return sampler diff --git a/verl/trainer/ppo/__init__.py b/verl/trainer/ppo/__init__.py index 1ce90c5eb35..7a7aadbc9d9 100644 --- a/verl/trainer/ppo/__init__.py +++ b/verl/trainer/ppo/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. \ No newline at end of file diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index ac9344cddcb..c678b241e25 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -15,16 +15,70 @@ """ Core functions to implement PPO algorithms. The function implemented in this file should be used by trainer with different distributed strategies to -implement PPO +implement PPO-like algorithms. """ +__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] + from collections import defaultdict +from enum import Enum import numpy as np import torch import verl.utils.torch_functional as verl_F +ADV_ESTIMATOR_REGISTRY = {} + +def register_adv_est(name_or_enum): + """Decorator to register an advantage estimator function with a given name.
+ + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + """ + def decorator(fn): + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: + raise ValueError(f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}") + ADV_ESTIMATOR_REGISTRY[name] = fn + return fn + return decorator + +def get_adv_estimator_fn(name_or_enum): + """Get the advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + Returns: + `(callable)`: The advantage estimator function. + """ + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name not in ADV_ESTIMATOR_REGISTRY: + raise ValueError(f"Unknown advantage estimator: {name}") + return ADV_ESTIMATOR_REGISTRY[name] + +class AdvantageEstimator(str, Enum): + """Using an enumeration class to avoid spelling errors in adv_estimator. + + Note(haibin.lin): this enum class is immutable after creation. Extending this + enum for new estimators may not be necessary since users can always just call + `verl.trainer.ppo.core_algos.register_adv_est` with string name for a custom advantage + estimator instead.
+ """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + OPO = "opo" + GRPO_PASSK = "grpo_passk" + class AdaptiveKLController: """ @@ -63,7 +117,7 @@ def get_kl_controller(kl_ctrl): else: raise NotImplementedError - +@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") def compute_gae_advantage_return( token_level_rewards: torch.Tensor, values: torch.Tensor, @@ -110,6 +164,7 @@ def compute_gae_advantage_return( # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") def compute_grpo_outcome_advantage( token_level_rewards: torch.Tensor, response_mask: torch.Tensor, @@ -165,13 +220,15 @@ def compute_grpo_outcome_advantage( return scores, scores - +@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") def compute_grpo_passk_outcome_advantage( token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, norm_adv_by_std_in_grpo: bool = True, + config = None, + **kwargs, ): """ Compute advantage for Pass@k using a GRPO-style outcome reward formulation. 
@@ -184,12 +241,15 @@ def compute_grpo_passk_outcome_advantage( response_mask: (bs, response_length) index: (bs,) → group ID per sample epsilon: float for numerical stability - norm_adv_by_std_in_grpo: if True, normalize advantage by std within group + config: (dict) algorithm settings, which contains "norm_adv_by_std_in_grpo" Returns: advantages: (bs, response_length) returns: (bs, response_length) """ + assert config is not None + # if True, normalize advantage by std within group + norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) scores = token_level_rewards.sum(dim=-1) # (bs,) advantages = torch.zeros_like(scores) @@ -219,8 +279,9 @@ def compute_grpo_passk_outcome_advantage( advantages = advantages.unsqueeze(-1) * response_mask return advantages, advantages - -def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6): +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE) # or simply: @register_adv_est("reinforce_plus_plus_baseline") +def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, + epsilon: float = 1e-6, config=None, **kwargs): """ Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward (with only one scalar reward for each response). 
@@ -230,6 +291,7 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (dict) algorithm config Returns: advantages: `(torch.Tensor)` @@ -262,8 +324,9 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: return scores, scores - -def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6): +@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") +def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, + epsilon: float = 1e-6, config=None, **kwargs): """ Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 @@ -272,6 +335,7 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_m shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (dict) algorithm config Returns: advantages: `(torch.Tensor)` @@ -303,8 +367,9 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_m return scores, scores - -def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6): +@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") +def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, + config=None, **kwargs): """ Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 @@ -313,6 +378,7 @@ def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_ma shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (dict) algorithm config Returns: advantages: `(torch.Tensor)` @@ -348,8 +414,8 @@ def 
compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_ma return scores, scores - -def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor): +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config=None, **kwargs): """ Compute advantage for REINFORCE++. This implementation is based on the paper: https://arxiv.org/abs/2501.03262 @@ -359,6 +425,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (dict) algorithm config Returns: advantages: `(torch.Tensor)` @@ -366,7 +433,8 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten Returns: `(torch.Tensor)` shape: (bs, response_length) """ - + assert config is not None + gamma = config.gamma with torch.no_grad(): returns = torch.zeros_like(token_level_rewards) running_return = 0 @@ -382,8 +450,8 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten return advantages, returns - -def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor): +@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") +def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor, config=None, **kwargs): """ Compute advantage for ReMax, operating only on Outcome reward This implementation is based on the paper: https://arxiv.org/abs/2310.10505 @@ -396,6 +464,7 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba shape: (bs,) response_mask: `(torch.Tensor)` shape: (bs, response_length) + 
config: (dict) algorithm config Returns: advantages: `(torch.Tensor)` diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 2bf21bf2ce0..e1fcbf11342 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -42,7 +42,7 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos -from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss from verl.trainer.ppo.metric_utils import ( compute_data_metrics, compute_throughout_metrics, @@ -58,7 +58,6 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger -from verl.workers.rollout.async_server import AsyncLLMServerManager WorkerType = Type[Worker] @@ -77,21 +76,6 @@ class Role(Enum): ActorRolloutRef = 6 -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" - REMAX = "remax" - RLOO = "rloo" - OPO = "opo" - GRPO_PASSK = "grpo_passk" - - @dataclass class ResourcePoolManager: """ @@ -212,7 +196,7 @@ def compute_response_mask(data: DataProto): return attention_mask[:, -response_length:] -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, **kwargs): +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, config=None): """Compute advantage estimates for policy optimization. 
This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. @@ -226,6 +210,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. Returns: DataProto: The updated data with computed advantages and returns. @@ -234,8 +219,8 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re if "response_mask" not in data.batch.keys(): data.batch["response_mask"] = compute_response_mask(data) # prepare response group - # TODO: add other ways to estimate advantages if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) advantages, returns = core_algos.compute_gae_advantage_return( token_level_rewards=data.batch["token_level_rewards"], values=data.batch["values"], @@ -245,19 +230,21 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re ) data.batch["advantages"] = advantages data.batch["returns"] = returns - if kwargs.get("use_pf_ppo", False): + if config.get("use_pf_ppo", False): data = core_algos.compute_pf_ppo_reweight_data( data, - kwargs.get("pf_ppo_reweight_method", "pow"), - kwargs.get("pf_ppo_weight_pow", 2.0), + config.get("pf_ppo_reweight_method", "pow"), + config.get("pf_ppo_weight_pow", 2.0), ) elif adv_estimator == AdvantageEstimator.GRPO: - # TODO: test on more adv estimator type + # Initialize the mask for GRPO calculation grpo_calculation_mask = data.batch["response_mask"] if multi_turn: # If multi-turn, replace the mask with the relevant part of loss_mask - response_length 
= grpo_calculation_mask.size(1) # Get length from the initial response mask - grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # This mask is the one intended for GRPO + # Get length from the initial response mask + response_length = grpo_calculation_mask.size(1) + # This mask is the one intended for GRPO + grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # Call compute_grpo_outcome_advantage with parameters matching its definition advantages, returns = core_algos.compute_grpo_outcome_advantage( token_level_rewards=data.batch["token_level_rewards"], @@ -267,58 +254,22 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re ) data.batch["advantages"] = advantages data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.GRPO_PASSK: - advantages, returns = core_algos.compute_grpo_passk_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE: - advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - gamma=gamma, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = 
core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - response_mask=data.batch["response_mask"], - ) + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = {"token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs['index'] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch:# optional + adv_kwargs['reward_baselines'] = data.batch["reward_baselines"] + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) data.batch["advantages"] = advantages data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.OPO: - advantages, returns = core_algos.compute_opo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError return data @@ -817,6 +768,8 @@ def init_workers(self): # create async rollout manager and request scheduler self.async_rollout_mode = False if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.workers.rollout.async_server import AsyncLLMServerManager + self.async_rollout_mode = True self.async_rollout_manager = AsyncLLMServerManager( config=self.config.actor_rollout_ref, @@ -1096,7 +1049,6 @@ def fit(self): 
reward_tensor, reward_extra_infos_dict = ray.get(future_reward) batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) @@ -1119,9 +1071,7 @@ def fit(self): num_repeat=self.config.actor_rollout_ref.rollout.n, norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable, - use_pf_ppo=self.config.algorithm.use_pf_ppo, - pf_ppo_reweight_method=self.config.algorithm.pf_ppo.reweight_method, - pf_ppo_weight_pow=self.config.algorithm.pf_ppo.weight_pow, + config=self.config.algorithm ) # update critic diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index 7f6910ef35f..323be6cb850 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -58,26 +58,32 @@ def wrapped_fn(*args, **kwargs): def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "batch": - from verl.workers.reward_manager import BatchRewardManager + """ + Load and initialize a reward manager based on the configuration. - reward_manager_cls = BatchRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. 
- reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError + Returns: + An instance of the specified reward manager class. + """ + from verl.workers.reward_manager import get_reward_manager_cls + + # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: + # naive: NaiveRewardManager + # prime: PrimeRewardManager + # batch: BatchRewardManager + # dapo: DAPORewardManager + # Note(haibin.lin): For custom reward managers, please make sure they are imported and + # registered via `verl.workers.reward_manager.register` + # By default reward_manager is set to naive (NaiveRewardManager) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + # Try to get a custom reward function based on the configuration compute_score = get_custom_reward_fn(config) final_compute_score = compute_score @@ -86,11 +92,13 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): sandbox_url = sandbox_config.get("url") if sandbox_config else None if sandbox_url: sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) final_compute_score = partial(default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) else: final_compute_score = default_compute_score + # Instantiate and return the reward manager with the specified parameters return reward_manager_cls( tokenizer=tokenizer, num_examine=num_examine, diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py index 6d62fd86b8c..85e521dfe52 100644 --- a/verl/utils/import_utils.py +++ b/verl/utils/import_utils.py @@ -48,6 +48,15 @@ def is_sglang_available(): return sglang_spec is not None +@cache +def is_trl_available(): + try: + trl_spec = importlib.util.find_spec("trl") + except 
ModuleNotFoundError: + trl_spec = None + return trl_spec is not None + + def import_external_libs(external_libs=None): if external_libs is None: return diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py new file mode 100644 index 00000000000..805759d4795 --- /dev/null +++ b/verl/utils/kernel/__init__.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .kernels import BackwardEnum, set_backward_method +from .linear_cross_entropy import linear_cross_entropy + +__all__ = ["linear_cross_entropy", "set_backward_method", "BackwardEnum"] diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py new file mode 100644 index 00000000000..c41b87088a8 --- /dev/null +++ b/verl/utils/kernel/kernels.py @@ -0,0 +1,1391 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementations of the linear cross entropy with token entropy kernel. 
+""" + +import typing +from dataclasses import dataclass + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + + +@dataclass +class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ + + _None = 0 + _Sum = 1 + _Mean = 2 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if reduction == "none": + _enum = EntropyReductionEnum._None + elif reduction == "sum": + _enum = EntropyReductionEnum._Sum + elif reduction == "mean": + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if ce_reduction == 0: + _enum = EntropyReductionEnum._None + elif ce_reduction == 1: + _enum = EntropyReductionEnum._Sum + elif ce_reduction == 2: + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + return _enum + + +@dataclass +class BackwardEnum: + """ + Enum for the backward method. + """ + + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size + _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. num_tokens + + +@dataclass +class Config: + _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N + _use_triton: bool = True + + +_config = Config() + + +def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. 
+ """ + global _config + _config._backward = backward_method + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + rcp_temperature: tl.float32, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = 
tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over K dimension + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + # load the next block of hidden and weight + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( + # (pid_n + 1) * vocab_per_split, vocab_size))), + # other=0.0) + + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), other=0.0) + + # GEMM + logits = tl.dot(_hidden, _weight.trans(), logits) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # scale logits by temperature + logits *= rcp_temperature + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = 
max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store logprobs + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) + mask &= offs_am < num_tokens + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue( + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + num_tokens, + num_splits, + global_max_ptr, + stride_global_max: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + global_entropy_ptr, + stride_global_entropy: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + foward epilogue + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), 
dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + + _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + _entropy_b = tl.load(entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # store entropy_b + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + # store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < 
num_tokens) + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + global_logprobs = -1 * global_logprobs + if reduction == 0: + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif reduction == 2: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, + num_splits, + reduced_max_ptr, + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_max_ptr, + stride_global_max: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load(reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * 
stride_reduced_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + _original_max = tl.load(original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + _accu = tl.load(accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load(entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update( + num_tokens, logprobs_ptr, stride_logprobs: tl.int64, maximum_ptr, stride_maximum: tl.int64, accumulate_ptr, stride_accumulate: tl.int64, entropy_b_ptr, stride_entropy_b: tl.int64, entropy_ptr, stride_entropy: tl.int64, logprobs_scalar_ptr, reduction: int, BLOCK_SIZE_M: tl.constexpr +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = 
tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + if reduction == 0: + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + logprobs_scalar = tl.sum(logprobs, axis=0) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif reduction == 2: + logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + + +_dedicated_stream, _dedicated_events = None, None + + +def efficient_entropy_forward(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction: typing.Optional[int] = 2, temperature: typing.Optional[float] = 1.0, dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + """ + forward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else 
dist.get_world_size(dist_process_group) + + if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + global _dedicated_stream, _dedicated_events + _dedicated_stream = torch.cuda.Stream(hidden.device) + _dedicated_events = [torch.cuda.Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): + logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) + else: + raise ValueError(f"Invalid reduction: {reduction}") + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, 
dtype=torch.float32) + + if REDUCTION == EntropyReductionEnum._None: + _logprobs = logprobs + else: + _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + + if _config._use_triton: + # 1D kernel launch, then split the tile + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + + efficient_entropy_kernel_general_mainloop[mainloop_grid]( + _rank, + hidden, + weight, + labels, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + hidden.stride(0), + hidden.stride(1), + weight.stride(0), + weight.stride(1), + _max, + _max.stride(0), + _max.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + _logprobs, + _logprobs.stride(0), + logprobs, + 1.0 / temperature, + ) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + # reduction on maximum and maximum_indices + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( + _max, + _max.stride(0), + _max.stride(1), + num_tokens, + num_splits, + maximum, + maximum.stride(0), + _accu, + _accu.stride(0), + _accu.stride(1), + accumulate, + accumulate.stride(0), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + _logprobs, + _logprobs.stride(0), + logprobs, + REDUCTION, + ) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + torch.cuda.current_stream().record_event(_dedicated_events[0]) + with torch.cuda.stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, 
op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( + num_tokens, + num_splits, + _max, + _max.stride(0), + _max.stride(1), + _max_backup, + _max_backup.stride(0), + _max_backup.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + ) + torch.cuda.current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid](num_tokens, _logprobs, _logprobs.stride(0), maximum, maximum.stride(0), accumulate, accumulate.stride(0), entropy_b, entropy_b.stride(0), entropy, entropy.stride(0), logprobs, REDUCTION) + + return (logprobs, entropy, maximum, accumulate, entropy_b) + + +# NOTE: merge d_weight & d_hidden here, split along M & N +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_mainloop_MN( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + 
stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward mainloop, where d_logits & d_hidden & d_weight are fused + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * 
stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n + d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + 
+ # scale d_logits by temperature + d_logits *= rcp_temperature + + # loop for d_weight & d_hidden + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) + # tl.atomic_add(d_weight_ptrs, + # _d_weight, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) + _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) + tl.atomic_add(d_weight_ptrs, _d_weight, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size)) + + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) + tl.atomic_add(d_hidden_ptrs, _d_hidden, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens)) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_hidden( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + 
stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_hidden + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_k = pid // num_pid_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + # iterate over vocab_size + d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): + offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, 
:] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over hidden_size to get logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits + d_logits *= rcp_temperature + + # calculate d_hidden + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) + _weight = tl.load(weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0) + d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) + + # write back + tl.store(d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, d_hidden, mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_weight( + num_tokens: int, + hidden_size: int, + 
vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + pid_n = pid % num_pid_n + pid_k = pid // num_pid_n + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): + offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) 
+ labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) + _hidden = tl.load(hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0) + d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) + + # write back + tl.store(d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, d_weight, mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size)) + + +# NOTE: split tile from d_logits' perspective +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", 
"vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = 
tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), other=0.0) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - 
maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store( + d_logits_ptrs, + d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=8), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits_split_N( + split_idx: int, + num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - 
first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) + _weight = tl.load(weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), other=0.0) + logits = tl.dot(_hidden, 
_weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + # filter d_logits with mask + result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) + + tl.store(d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + should_return_fp32_grad: bool = False, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> typing.List[torch.Tensor]: + """ + backward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + assert 
vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + d_hidden, d_weight = None, None + if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + else: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + if _config._backward == BackwardEnum._Total_Fuse_MN: + # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. 
+ def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), + 1.0 / temperature, + ) + + elif _config._backward == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + if _config._use_triton: + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + torch.matmul(_d_logits, weight, out=d_hidden) + torch.matmul(_d_logits.T, hidden, out=d_weight) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + elif _config._backward == BackwardEnum._Split_Dlogits_N: + vocab_per_split = 9504 + num_splits = (vocab_size + vocab_per_split - 1) // 
vocab_per_split + + _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) + + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split + _d_logits = _d_logits[:, :vocab_right_bound].contiguous() + + if split_idx == 0: + torch.matmul(_d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden) + else: + d_hidden += torch.matmul(_d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]) + torch.matmul(_d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]) + + elif _config._backward == BackwardEnum._Split_Dlogits_M: + raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") + + return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py new file mode 100644 index 00000000000..8a7d43ec329 --- /dev/null +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,94 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import torch +import torch.distributed as dist + +from . import kernels + + +class LinearCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: typing.Optional[float] = 1.0, reduction: typing.Optional[str] = "none", dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + """_summary_ + + Args: + ctx (_type_): _description_ + hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size) + weight (torch.Tensor): (vocab_size, hidden_size) + labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, ) + temperature (typing.Optional[float], optional): _description_. 
Defaults to 1.0. + reduction (typing.Optional[str], optional): _description_. Defaults to "none". + dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None. + + Returns: + typing.List[torch.Tensor]: _description_ + """ + + assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}" + assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}" + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + original_hidden_shape = hidden.shape + if len(hidden.shape) != 2: + hidden = hidden.view(-1, hidden.shape[-1]) # (batch_size * num_tokens, hidden_size) + if len(labels.shape) != 1: + labels = labels.view(-1) + + logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(hidden, weight, labels, REDUCTION, temperature, dist_process_group) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors + REDUCTION = ctx.REDUCTION + dist_process_group = ctx.dist_process_group + should_return_fp32_grad = ctx.should_return_fp32_grad + temperature = ctx.temperature + + d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, _maximum, _accumulate, _entropy_b, REDUCTION, should_return_fp32_grad, temperature, dist_process_group) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return (d_hidden, d_weight, None, 
None, None, None) + + +linear_cross_entropy = LinearCrossEntropy.apply diff --git a/verl/utils/model.py b/verl/utils/model.py index 11b944c6249..24482d06213 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -33,6 +33,7 @@ ) from verl.models.registry import ModelRegistry +from verl.utils.import_utils import is_trl_available class LambdaLayer(nn.Module): @@ -469,3 +470,70 @@ def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, pos parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig) return parallel_model + + +def patch_valuehead_model(model) -> None: + from types import MethodType + + from transformers import PreTrainedModel + + from trl import AutoModelForCausalLMWithValueHead + + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + + def can_generate(self): + return False + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + setattr(model, "_keys_to_ignore_on_save", ignore_modules) + setattr(model, "tie_weights", MethodType(tie_weights, model)) + setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) + setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model)) + setattr(model, "can_generate", MethodType(can_generate, model)) + setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", [])) + + +def load_valuehead_model(local_path, 
torch_dtype, model_config, trust_remote_code): + from transformers import AutoModelForTokenClassification, AutoModelForCausalLM, AutoModelForVision2Seq + + try: + model = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + return model + except BaseException as e: + if not is_trl_available(): + raise RuntimeError(f"model({local_path}) is not a value head model, please install trl to make it valid") from e + + assert is_trl_available() + + from trl import AutoModelForCausalLMWithValueHead + + if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): + module_class = AutoModelForVision2Seq + else: + module_class = AutoModelForCausalLM + ori_model = module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model) + patch_valuehead_model(model) + return model diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index aa441cf5ff2..535e6148dbf 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -18,6 +18,7 @@ import torch from torch import distributed as dist + from verl.utils.device import get_device_name diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index fba18c8e73b..b520f7c346d 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -27,7 +27,8 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR from transformers import PreTrainedTokenizer -from verl.utils.device import get_torch_device, get_device_name + +from verl.utils.device import get_device_name, get_torch_device try: from flash_attn.ops.triton.cross_entropy import 
cross_entropy_loss diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 8b81b66b86a..9eaf80e94e0 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -136,6 +136,7 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False extra_args = {} if self.use_fused_kernels: extra_args["temperature"] = temperature + extra_args["return_dict"] = True output = self.actor_module( input_ids=input_ids_rmpad, diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index 5579f58d4c0..7e7b904b022 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -99,8 +99,13 @@ def _forward_micro_batch(self, micro_batch): **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating - values_rmpad = output.logits - values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values_rmpad = output[2].squeeze(0).unsqueeze(-1) + else: + values_rmpad = output.logits + values_rmpad = values_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: @@ -117,7 +122,11 @@ def _forward_micro_batch(self, micro_batch): **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating - values = output.logits + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values = output[2] + else: + values = output.logits values = values[:, -response_length - 1 : -1].squeeze(-1) return values @@ -213,7 +222,7 @@ def update_critic(self, data: DataProto): micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = 
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu self.critic_optimizer.zero_grad() diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 8c2626bb690..7e1e50164e6 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -242,11 +242,15 @@ def _build_model_optimizer( _apply_liger_kernel_to_instance(model=actor_module) + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + apply_monkey_patch( model=actor_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 @@ -545,7 +549,6 @@ def init_model(self): optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", @@ -553,7 +556,6 @@ def init_model(self): OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding - self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) if self._is_actor: @@ -823,7 +825,7 @@ def _build_critic_model_optimizer(self, config): from torch import optim from torch.distributed.fsdp import MixedPrecision - from verl.utils.model import print_model_size + from verl.utils.model import load_valuehead_model, print_model_size from verl.utils.torch_dtypes import PrecisionType use_shm = config.model.get("use_shm", False) @@ -864,11 +866,13 @@ def _build_critic_model_optimizer(self, 
config): warnings.simplefilter("ignore") critic_model_config.classifier_dropout = 0.0 critic_model_config.hidden_dropout = "0" - critic_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=critic_model_config, - trust_remote_code=config.model.get("trust_remote_code", False), + critic_model_config.summary_dropout_prob = 0.0 + + critic_module = load_valuehead_model( + local_path, + torch_dtype, + critic_model_config, + config.model.get("trust_remote_code", False), ) use_remove_padding = config.model.get("use_remove_padding", False) @@ -1253,7 +1257,7 @@ def _forward_micro_batch(self, micro_batch): input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False) # prevent model thinks we are generating + output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False) reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py index 4a979f6944c..5d19e7cc1f5 100644 --- a/verl/workers/reward_manager/__init__.py +++ b/verl/workers/reward_manager/__init__.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .registry import get_reward_manager_cls, register # noqa: I001 from .batch import BatchRewardManager from .dapo import DAPORewardManager from .naive import NaiveRewardManager from .prime import PrimeRewardManager -__all__ = ["BatchRewardManager", "DAPORewardManager", "NaiveRewardManager", "PrimeRewardManager"] +# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies +__all__ = ["BatchRewardManager", "DAPORewardManager", "NaiveRewardManager", "PrimeRewardManager", "register", "get_reward_manager_cls"] diff --git a/verl/workers/reward_manager/batch.py b/verl/workers/reward_manager/batch.py index 570fdd71dea..fea8a385162 100644 --- a/verl/workers/reward_manager/batch.py +++ b/verl/workers/reward_manager/batch.py @@ -17,8 +17,10 @@ import torch from verl import DataProto +from verl.workers.reward_manager import register +@register("batch") class BatchRewardManager: def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs): self.tokenizer = tokenizer diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index 399cdf05e09..2787361d950 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -18,8 +18,10 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register +@register("dapo") class DAPORewardManager: """The reward manager.""" diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 59ad618c4c1..ec9709ba8e6 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -18,16 +18,27 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register +@register("naive") class NaiveRewardManager: """The reward manager.""" def __init__(self, tokenizer, num_examine, compute_score=None, 
reward_fn_key="data_source") -> None: - self.tokenizer = tokenizer + """ + Initialize the NaiveRewardManager instance. + + Args: + tokenizer: The tokenizer used to decode token IDs into text. + num_examine: The number of batches of decoded responses to print to the console for debugging purpose. + compute_score: A function to compute the reward score. If None, `default_compute_score` will be used. + reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to "data_source". + """ + self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or default_compute_score - self.reward_fn_key = reward_fn_key + self.reward_fn_key = reward_fn_key # Store the key for accessing the data source def __call__(self, data: DataProto, return_dict=False): """We will expand this function gradually based on the available datasets""" @@ -63,9 +74,7 @@ def __call__(self, data: DataProto, return_dict=False): response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] - data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get("extra_info", None) score = self.compute_score( diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index d1a68d85f96..b65eb5ff738 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -23,6 +23,7 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): @@ -94,7 +95,7 @@ def run_reward_scoring(evaluation_func, completions, references, tasks, extra_in finally: 
loop.close() - +@register("prime") class PrimeRewardManager: """ The Reward Manager used in https://github.com/PRIME-RL/PRIME diff --git a/verl/workers/reward_manager/registry.py b/verl/workers/reward_manager/registry.py new file mode 100644 index 00000000000..5c894437524 --- /dev/null +++ b/verl/workers/reward_manager/registry.py @@ -0,0 +1,45 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['register', "get_reward_manager_cls"] + +REWARD_MANAGER_REGISTRY = {} + +def register(name): + """Decorator to register a reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + """ + def decorator(cls): + if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls: + raise ValueError(f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}") + REWARD_MANAGER_REGISTRY[name] = cls + return cls + return decorator + +def get_reward_manager_cls(name): + """Get the reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + + Returns: + `(type)`: The reward manager class. + """ + if name not in REWARD_MANAGER_REGISTRY: + raise ValueError(f"Unknown reward manager: {name}") + return REWARD_MANAGER_REGISTRY[name] \ No newline at end of file