1 change: 1 addition & 0 deletions .github/workflows/e2e_ppo_trainer.yml
@@ -190,6 +190,7 @@ jobs:
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
SP_SIZE=2 \
bash tests/e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_sglang:
2 changes: 2 additions & 0 deletions tests/e2e/ppo_trainer/run_function_reward.sh
@@ -40,6 +40,7 @@ TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}
# whether to save hf_model
SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}
FSDP_SIZE=${FSDP_SIZE:--1}
SP_SIZE=${SP_SIZE:-1}

if [ "${SAVE_HF_MODEL}" = "True" ]; then
    CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']"
@@ -92,6 +93,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
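Taken together with the workflow change above, the new SP_SIZE variable defaults to 1 (sequence parallelism disabled) and is forwarded to the trainer as the Hydra override actor_rollout_ref.actor.ulysses_sequence_parallel_size, so a local run such as SP_SIZE=2 bash tests/e2e/ppo_trainer/run_function_reward.sh exercises 2-way Ulysses SP the same way the new CI entry does.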
66 changes: 61 additions & 5 deletions verl/models/transformers/monkey_patch.py
@@ -25,12 +25,12 @@
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_utils import PreTrainedModel

from verl.models.transformers.llama import forward_for_ppo
from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_world_size,
    slice_input_tensor,
)


@@ -107,6 +107,37 @@ def _ulysses_flash_attention_forward(
    return attn_output


def patch_vlm_for_ulysses_input_slicing(model_class: type):
    """
    Applies a monkey patch to the forward method of a given model class
    to enable Ulysses sequence parallelism input slicing.
    """

    def _create_ulysses_wrapped_decoder_forward(original_forward):
        def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
            inputs_embeds = kwargs.get("inputs_embeds")
            call_kwargs = kwargs.copy()

            current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

            slice_now = (
                inputs_embeds is not None
                and current_ulysses_sp_size > 1
                and getattr(self, "_needs_initial_slice", True)
            )
            if slice_now:
                call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
                self._needs_initial_slice = False
            try:
                return original_forward(self, *args, **call_kwargs)
            finally:
                if slice_now:
                    self._needs_initial_slice = True

        return ulysses_wrapped_decoder_forward

    original_forward = model_class.forward
    wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward)
    model_class.forward = wrapped_forward
    print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.")
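As a hedged illustration of the patch above (ToyTextModel is a hypothetical stand-in, not one of the real transformers classes targeted below): the wrapper slices only when inputs_embeds is supplied, the sequence-parallel world size exceeds 1, and _needs_initial_slice is still set; the try/finally re-arms the flag so the next top-level call slices again.

# Minimal sketch; ToyTextModel is hypothetical and the SP group may be uninitialized.
class ToyTextModel:
    def forward(self, *args, inputs_embeds=None, **kwargs):
        # Stand-in for a decoder forward; returns its input unchanged.
        return inputs_embeds

patch_vlm_for_ulysses_input_slicing(ToyTextModel)
# With a world size of 1 the wrapped forward is a pass-through; with sp_size > 1
# each rank would receive its seq_len / sp_size slice of inputs_embeds along dim=1.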


def apply_monkey_patch(
    model: PreTrainedModel,
    ulysses_sp_size: int = 1,
@@ -134,6 +165,14 @@ def apply_monkey_patch(
Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
print("Monkey patch FlashAttention2.forward in Qwen2.5VL")

if ulysses_sp_size > 1:
    if is_transformers_version_in_range(min_version="4.52.0"):
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

        patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
    else:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel

        patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)

if use_fused_kernels:
    from verl.models.transformers.qwen2_5_vl import forward_for_ppo

@@ -153,6 +192,14 @@ def apply_monkey_patch(
Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
print("Monkey patch FlashAttention2.forward in Qwen2VL")

if ulysses_sp_size > 1:
    if is_transformers_version_in_range(min_version="4.52.0"):
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

        patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
    else:
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

        patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)

if use_fused_kernels:
    from verl.models.transformers.qwen2_vl import forward_for_ppo

@@ -179,12 +226,21 @@ def apply_monkey_patch(


@lru_cache
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
    try:
        # Get the installed version of the transformers library
        transformers_version = importlib.metadata.version("transformers")
        transformers_version_str = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as e:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from e

    # Check if the version is within the specified range
    return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
    transformers_version = version.parse(transformers_version_str)

    lower_bound_check = True
    if min_version is not None:
        lower_bound_check = version.parse(min_version) <= transformers_version

    upper_bound_check = True
    if max_version is not None:
        upper_bound_check = transformers_version <= version.parse(max_version)

    return lower_bound_check and upper_bound_check
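The relaxed signature is what allows the one-sided checks in the Qwen2-VL branches above; a few illustrative calls (version strings are examples, not pinned requirements):

is_transformers_version_in_range(min_version="4.52.0")  # lower bound only
is_transformers_version_in_range(max_version="4.51.3")  # upper bound only
is_transformers_version_in_range("4.45.0", "4.52.0")    # inclusive on both ends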
31 changes: 19 additions & 12 deletions verl/utils/ulysses.py
Original file line number Diff line number Diff line change
@@ -268,6 +268,22 @@ def gather_outpus_and_unpad(
        x = _unpad_tensor(x, unpad_dim, padding_size)
    return x

def ulysses_pad(
    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
):
    """Pad packed inputs on the sequence dim so the total length divides sp_size; returns (ids, pos_ids, pad_size)."""
    if position_ids_rmpad is not None:
        assert position_ids_rmpad.size(0) == 1
        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
    if sp_size <= 1:
        return input_ids_rmpad, position_ids_rmpad, 0
    _, total_seq_len = input_ids_rmpad.shape
    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
    if pad_size > 0:
        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
        if position_ids_rmpad is not None:
            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
    return input_ids_rmpad, position_ids_rmpad, pad_size
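A quick worked example of the padding arithmetic (toy values; no SP group is required since ulysses_pad only pads): a packed length of 10 with sp_size=4 gives pad_size = (4 - 10 % 4) % 4 = 2, so both tensors grow to 12 tokens, which divides evenly across four ranks.

import torch

from verl.utils.ulysses import ulysses_pad

input_ids = torch.randint(0, 100, (1, 10))    # (1, total_seq_len)
position_ids = torch.arange(10).unsqueeze(0)  # (1, total_seq_len)
ids, pos, pad_size = ulysses_pad(input_ids, position_ids, sp_size=4)
assert pad_size == 2
assert ids.shape == (1, 12) and pos.shape == (1, 12)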

def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1):
"""
@@ -288,18 +304,9 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmp
        torch.Tensor: padded and sliced position_ids
        int: pad size
    """
    if position_ids_rmpad is not None:
        assert position_ids_rmpad.size(0) == 1
        assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
    if sp_size <= 1:
        return input_ids_rmpad, position_ids_rmpad, 0
    _, total_seq_len = input_ids_rmpad.shape
    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
    if pad_size > 0:
        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
        if position_ids_rmpad is not None:
            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
        input_ids_rmpad, position_ids_rmpad, sp_size
    )
    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
    if position_ids_rmpad is not None:
        position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)
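Continuing the toy numbers above: with sp_size=4, ulysses_pad_and_slice_inputs pads the 10-token batch to 12 and then each rank keeps a 3-token shard, assuming slice_input_tensor returns the current rank's contiguous slice along dim=1 (a sketch; this call requires an initialized Ulysses SP group):

ids, pos, pad_size = ulysses_pad_and_slice_inputs(input_ids, position_ids, sp_size=4)
# pad_size == 2; ids.shape == (1, 3) on every rank (12 padded tokens / 4 ranks)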
21 changes: 15 additions & 6 deletions verl/workers/actor/dp_actor.py
@@ -35,7 +35,7 @@
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor

if is_cuda_available:
@@ -108,11 +108,20 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False

# pad and slice the inputs if sp > 1
if self.use_ulysses_sp:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
    input_ids_rmpad,
    position_ids_rmpad=position_ids_rmpad,
    sp_size=self.ulysses_sequence_parallel_size,
)
is_vlm_model = "multi_modal_inputs" in micro_batch
if is_vlm_model:
    # vlm model's inputs will be sliced after embedding
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
        input_ids_rmpad,
        position_ids_rmpad=position_ids_rmpad,
        sp_size=self.ulysses_sequence_parallel_size,
    )
else:
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
        input_ids_rmpad,
        position_ids_rmpad=position_ids_rmpad,
        sp_size=self.ulysses_sequence_parallel_size,
    )
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
    input_ids_rmpad_rolled,
    position_ids_rmpad=None,
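To make the branch above concrete (hypothetical shapes, assuming sp_size == 2 and a packed length of 9): a VLM must keep input_ids whole until image features are merged into the token embeddings, so only ulysses_pad runs here and the patched text-model forward from monkey_patch.py slices inputs_embeds afterwards, while the text-only path pads and slices input_ids immediately.

# Illustrative shape walk-through, not taken from a real run:
# VLM path:  ulysses_pad                  -> ids (1, 10), pad_size 1; sliced to (1, 5) after embedding
# LLM path:  ulysses_pad_and_slice_inputs -> ids (1, 5) on each rank directly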