diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml
index 39372a3fff0..cf77df2229b 100644
--- a/.github/workflows/e2e_ppo_trainer.yml
+++ b/.github/workflows/e2e_ppo_trainer.yml
@@ -190,6 +190,7 @@ jobs:
           MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
           MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \
           ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
+          SP_SIZE=2 \
           bash tests/e2e/ppo_trainer/run_function_reward.sh
 
   e2e_ppo_trainer_sglang:
diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh
index c49cdf83e70..7095b23c7e9 100644
--- a/tests/e2e/ppo_trainer/run_function_reward.sh
+++ b/tests/e2e/ppo_trainer/run_function_reward.sh
@@ -40,6 +40,7 @@ TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}
 # whether to save hf_model
 SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}
 FSDP_SIZE=${FSDP_SIZE:--1}
+SP_SIZE=${SP_SIZE:-1}
 
 if [ "${SAVE_HF_MODEL}" = "True" ]; then
     CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']"
@@ -92,6 +93,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
     actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
     actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
     actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
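Reviewer note: `SP_SIZE` defaults to 1, so the new `actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}"` override is a no-op for every existing caller of `run_function_reward.sh`; only the Qwen2-VL CI job above opts in, via `SP_SIZE=2`.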
+ """ + + def _create_ulysses_wrapped_decoder_forward(original_forward): + def ulysses_wrapped_decoder_forward(self, *args, **kwargs): + inputs_embeds = kwargs.get("inputs_embeds") + call_kwargs = kwargs.copy() + + current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + slice_now = inputs_embeds is not None and current_ulysses_sp_size > 1 and getattr(self, "_needs_initial_slice", True) + if slice_now: + call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) + self._needs_initial_slice = False + try: + return original_forward(self, *args, **call_kwargs) + finally: + if slice_now: + self._needs_initial_slice = True + + return ulysses_wrapped_decoder_forward + + original_forward = model_class.forward + wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward) + model_class.forward = wrapped_forward + print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") + + def apply_monkey_patch( model: PreTrainedModel, ulysses_sp_size: int = 1, @@ -134,6 +165,14 @@ def apply_monkey_patch( Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2.5VL") + if ulysses_sp_size > 1: + if is_transformers_version_in_range(min_version="4.52.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) + if use_fused_kernels: from verl.models.transformers.qwen2_5_vl import forward_for_ppo @@ -153,6 +192,14 @@ def apply_monkey_patch( Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2VL") + if ulysses_sp_size > 1: + if is_transformers_version_in_range(min_version="4.52.0"): + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) + else: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel + patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) + if use_fused_kernels: from verl.models.transformers.qwen2_vl import forward_for_ppo @@ -179,12 +226,21 @@ def apply_monkey_patch( @lru_cache -def is_transformers_version_in_range(min_version: str, max_version: str) -> bool: +def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool: try: # Get the installed version of the transformers library - transformers_version = importlib.metadata.version("transformers") + transformers_version_str = importlib.metadata.version("transformers") except importlib.metadata.PackageNotFoundError as e: raise ModuleNotFoundError("The `transformers` package is not installed.") from e - # Check if the version is within the specified range - return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version) + transformers_version = version.parse(transformers_version_str) + + lower_bound_check = True + if min_version is not None: + lower_bound_check = version.parse(min_version) <= transformers_version + + upper_bound_check = True + if max_version is not None: + upper_bound_check = transformers_version <= version.parse(max_version) + + return lower_bound_check and upper_bound_check diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py index a33293364f1..3670b1d20a1 100644 --- a/verl/utils/ulysses.py +++ 
diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py
index a33293364f1..3670b1d20a1 100644
--- a/verl/utils/ulysses.py
+++ b/verl/utils/ulysses.py
@@ -268,6 +268,22 @@ def gather_outpus_and_unpad(
     x = _unpad_tensor(x, unpad_dim, padding_size)
     return x
 
 
+def ulysses_pad(
+    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
+):
+    if position_ids_rmpad is not None:
+        assert position_ids_rmpad.size(0) == 1
+        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
+    if sp_size <= 1:
+        return input_ids_rmpad, position_ids_rmpad, 0
+    _, total_seq_len = input_ids_rmpad.shape
+    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
+    if pad_size > 0:
+        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
+        if position_ids_rmpad is not None:
+            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
+            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
+    return input_ids_rmpad, position_ids_rmpad, pad_size
+
+
 def ulysses_pad_and_slice_inputs(
     input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
 ):
     """
@@ -288,18 +304,9 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmp
         torch.Tensor: padded and sliced position_ids
         int: pad size
     """
-    if position_ids_rmpad is not None:
-        assert position_ids_rmpad.size(0) == 1
-        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
-    if sp_size <= 1:
-        return input_ids_rmpad, position_ids_rmpad, 0
-    _, total_seq_len = input_ids_rmpad.shape
-    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
-    if pad_size > 0:
-        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
-        if position_ids_rmpad is not None:
-            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
-            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
+    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
+        input_ids_rmpad, position_ids_rmpad, sp_size
+    )
     input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
     if position_ids_rmpad is not None:
         position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)
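A quick worked check of the padding arithmetic, `pad_size = (sp_size - total_seq_len % sp_size) % sp_size`: the new `ulysses_pad` pads to the next multiple of `sp_size` without slicing, appending zero tokens and fresh position ids. The snippet below assumes verl and torch are importable; the tensor sizes are made up for illustration.

```python
import torch

from verl.utils.ulysses import ulysses_pad  # added by this patch

# A packed (rmpad) batch of 10 tokens with sp_size=4: (4 - 10 % 4) % 4 == 2,
# so two zero tokens and two fresh position ids get appended.
input_ids = torch.arange(10).unsqueeze(0)     # shape (1, 10)
position_ids = torch.arange(10).unsqueeze(0)  # shape (1, 10)

padded_ids, padded_pos, pad_size = ulysses_pad(input_ids, position_ids, sp_size=4)
assert pad_size == 2
assert padded_ids.shape == (1, 12) and padded_pos.shape == (1, 12)

# Already divisible: (4 - 12 % 4) % 4 == 0, nothing is padded.
_, _, pad_size = ulysses_pad(padded_ids, padded_pos, sp_size=4)
assert pad_size == 0
```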
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 0f1c7c562bf..81d701ebb92 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -35,7 +35,7 @@
 from verl.utils.py_functional import append_to_dict
 from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
 from verl.utils.torch_functional import logprobs_from_logits
-from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
+from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs, ulysses_pad
 from verl.workers.actor import BasePPOActor
 
 if is_cuda_available:
@@ -108,11 +108,20 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False
 
                 # pad and slice the inputs if sp > 1
                 if self.use_ulysses_sp:
-                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
-                        input_ids_rmpad,
-                        position_ids_rmpad=position_ids_rmpad,
-                        sp_size=self.ulysses_sequence_parallel_size,
-                    )
+                    is_vlm_model = "multi_modal_inputs" in micro_batch
+                    if is_vlm_model:
+                        # vlm model's inputs will be sliced after embedding
+                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
+                            input_ids_rmpad,
+                            position_ids_rmpad=position_ids_rmpad,
+                            sp_size=self.ulysses_sequence_parallel_size,
+                        )
+                    else:
+                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
+                            input_ids_rmpad,
+                            position_ids_rmpad=position_ids_rmpad,
+                            sp_size=self.ulysses_sequence_parallel_size,
+                        )
                     input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                         input_ids_rmpad_rolled,
                         position_ids_rmpad=None,
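Taken together, the two halves of the patch line up as follows. Text-only batches keep the old pad-then-slice path (`ulysses_pad_and_slice_inputs`). Batches carrying `multi_modal_inputs` are only padded (`ulysses_pad`); slicing is deferred to the patched Qwen2-VL / Qwen2.5-VL language-model forward, which cuts `inputs_embeds` along the sequence dimension only after the vision features have been merged into the embeddings. Slicing raw `input_ids` per rank before the embedding step would scatter the image placeholder tokens across SP ranks before the vision-encoder outputs are written into them, which is what the deferred slicing avoids. Note that `input_ids_rmpad_rolled` (the shifted labels) still goes through `ulysses_pad_and_slice_inputs` unconditionally in both branches, since each rank only needs the label shard matching its local logits.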