diff --git a/recipe/dapo/test_dapo_gptoss_20b_megatron.sh b/recipe/dapo/test_dapo_gptoss_20b_megatron.sh new file mode 100644 index 00000000000..021a8b3478a --- /dev/null +++ b/recipe/dapo/test_dapo_gptoss_20b_megatron.sh @@ -0,0 +1,251 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +################################################### document for gptoss ################################################### + +####################### running environment: ####################### +# option 1: use a pre-built docker image dedicated for gptoss: `docker://iseekyan/verl:nemo.gptoss_vllm0.11.0`, which is +# built upon nemo's dedicated image, see Dockerfile at https://github.com/volcengine/verl/blob/main/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.vllm011.mcore_gpt-oss +# +# option 2: self build TE>=2.8 with CUDNN>=9.13.1, megatron with branch `core_dev_r0.15.0`, latest vllm or sglang +# you can modify the dockerfile to build the image, see Dockerfile at https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.vllm or https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.sglang + +####################### before training: ####################### +# # install matched mbridge version +# pip uninstall -y mbridge && pip install git+https://github.com/ISEEKYAN/mbridge@gpt-oss + +# # convert gptoss to bf16 +cat > get_model.py << EOF +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config + +model_id = "openai/gpt-oss-20b" +output_dir = "$HOME/models/gpt-oss-20b-bf16" + +quantization_config = Mxfp4Config(dequantize=True) +model_kwargs = dict( + attn_implementation="eager", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + use_cache=False, + device_map="auto", +) + +model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + +# Patch config with custom attribute before saving +model.config.attn_implementation = "eager" + +model.save_pretrained(output_dir) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.save_pretrained(output_dir) +EOF + +python get_model.py + +####################### specific training config: ####################### + +GPT_OSS_CONFIG=( + # only support mbridge for gptoss + actor_rollout_ref.actor.megatron.use_mbridge=True + # for now (latest TE=2.10), gptoss's optimized attn kernel is not supported for thd format, so we use bshd format here + # when bshd format is used, we need to pad the input_ids to the longest sequence length + # so we recommend to disable dynamic batch size and set micro batch size to 1 to avoid paddings + # but it is ok to try with micro_batch_size>1 + actor_rollout_ref.actor.megatron.use_remove_padding=False +) +use_dynamic_bsz=False # recommended but not necessary + +################################################### quick config ################################################### + +rollout_mode="sync" +rollout_name="vllm" # sglang or vllm +return_raw_chat="False" +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi +dtype="bfloat16" # ["bfloat16", "float16"] + +project_name='DAPO' +exp_name='gptoss' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/gpt-oss-20b"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=4 +train_tp=4 +EP=8 +ETP=1 +train_pp=1 + +################################################### start of config ################################################### + + +DATA=( + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + data.prompt_key=prompt + data.return_raw_chat=$return_raw_chat + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} +) + +REWARD_MODEL=( + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False + +reward_model.reward_kwargs.max_resp_len=${max_response_length} + reward_model.reward_manager=dapo +) + +PERF_OPT=( + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True + actor_rollout_ref.model.use_fused_kernels=False + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True +) + +ACTOR=( + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=10 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.optim.clip_grad=1.0 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.megatron.param_offload=${offload} + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} + actor_rollout_ref.actor.megatron.grad_offload=${offload} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} +) + +ROLLOUT=( + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.mode=${rollout_mode} + actor_rollout_ref.rollout.dtype=${dtype} + actor_rollout_ref.rollout.gpu_memory_utilization=0.70 + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.n=${n_resp_per_prompt} +) + +TRAINER=( + trainer.logger=['console','wandb'] + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.n_gpus_per_node=8 + trainer.nnodes="${NNODES}" + trainer.val_before_train=False + trainer.test_freq=10 + trainer.save_freq=-1 + trainer.total_epochs=10 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.resume_mode=auto + trainer.log_val_generations=10 +) + +FORWARD_ONLY_SETS=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} +) + +MODEL=( + actor_rollout_ref.model.path="${MODEL_PATH}" +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) +################################################### start script ################################################### +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + -- python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REWARD_MODEL[@]}" \ + "${PERF_OPT[@]}" \ + "${TRAINER[@]}" \ + "${GPT_OSS_CONFIG[@]}" \ + "${FORWARD_ONLY_SETS[@]}" \ \ No newline at end of file diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index c235917cc96..3a9d6bb4aba 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -18,10 +18,14 @@ from verl.utils.megatron_utils import unwrap_model from .util import ( + postprocess_bshd, + postprocess_bshd_no_padding, postprocess_packed_seqs, - postprocess_packed_seqs_no_padding, + postprocess_thd_no_padding, + preprocess_bshd, + preprocess_bshd_no_padding, preprocess_packed_seqs, - preprocess_packed_seqs_no_padding, + preprocess_thd_no_padding, ) @@ -35,12 +39,15 @@ def model_forward( logits_processor=None, logits_processor_args: dict = None, value_model=False, + data_format: str = "thd", ): """Forward pass for models with sequence packing.""" + assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'" pre_process = ( unwrap_model(model).pre_process if not vision_model else False ) # vision model does not need pre_process, because we pack the input_ids to thd in the forward function post_process = unwrap_model(model).post_process + sp = unwrap_model(model).config.sequence_parallel fp8 = unwrap_model(model).config.fp8 use_fp8_padding = fp8 in ["e4m3", "hybrid"] @@ -55,44 +62,82 @@ def model_forward( model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( - input_ids, attention_mask, pre_process=pre_process, use_fp8_padding=use_fp8_padding - ) - input_ids_rmpad = input_ids_rmpad.contiguous() - - input_args = dict( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids if not vision_model else None, # vision models will calculate position_ids - packed_seq_params=packed_seq_params, - **model_kwargs, - ) - - if vision_model: - # workaround for supporting sequence packing with context parallelism - # cp split with sequence packing will make model lose vision token information, so we need to keep - # the original input_ids and pack them after vision embedding is calculated, - # cooporate with mbridge - input_args["input_ids"] = input_ids - input_args["attention_mask"] = attention_mask + if data_format == "thd": + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=pre_process, use_fp8_padding=use_fp8_padding + ) + input_ids_rmpad = input_ids_rmpad.contiguous() + + input_args = dict( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids if not vision_model else None, # vision models will calculate position_ids + packed_seq_params=packed_seq_params, + **model_kwargs, + ) - output_orig = model(**input_args) - if post_process and logits_processor is not None: - args = { - k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0] - for k, v in logits_processor_args.items() - } - output_dict = logits_processor(output_orig, **args) - output = { - k: postprocess_packed_seqs( - v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + if vision_model: + # workaround for supporting sequence packing with context parallelism + # cp split with sequence packing will make model lose vision token information, so we need to keep + # the original input_ids and pack them after vision embedding is calculated, + # cooporate with mbridge + input_args["input_ids"] = input_ids + input_args["attention_mask"] = attention_mask + + output_orig = model(**input_args) + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process ) - for k, v in output_dict.items() - } - else: - output = postprocess_packed_seqs( - output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + elif data_format == "bshd": + """ + data_format: "thd" or "bshd", default is "thd", + why we need this? + for some new models, GPT-OSS, the thd format is not supported, so we need to use the bshd format. + When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length, + so it is recommended to disable dynamic batch size and set batch size to 1 + """ + assert not vision_model, "vision model does not support bshd format" + assert fp8 is None, "fp8 is not supported for bshd format yet" + + batch_size, sequence_length = attention_mask.shape[:2] + new_input_ids, new_attention_mask, new_position_ids = preprocess_bshd( + input_ids, attention_mask, position_ids, sequence_parallel=sp, pre_process=pre_process + ) + output_orig = model( + input_ids=new_input_ids, + position_ids=new_position_ids, + attention_mask=new_attention_mask, + **model_kwargs, ) + if post_process and logits_processor is not None: + args = { + k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=sp, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_bshd( + v, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_bshd( + output_orig, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) if value_model and post_process: output = output[..., 0] return output @@ -107,8 +152,11 @@ def gptmodel_forward_no_padding( logits_processor=None, logits_processor_args: dict = None, value_model=False, + data_format: str = "thd", ): """Default forward pass for GPT models with optional sequence packing.""" + + assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'" pre_process = unwrap_model(model).pre_process post_process = unwrap_model(model).post_process @@ -117,34 +165,67 @@ def gptmodel_forward_no_padding( model_kwargs["pixel_values"] = multi_modal_inputs["pixel_values"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs: model_kwargs["image_grid_thw"] = multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "pixel_values_videos" in multi_modal_inputs: + model_kwargs["pixel_values_videos"] = multi_modal_inputs["pixel_values_videos"].to(input_ids.device) + if "video_grid_thw" in multi_modal_inputs: + model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) batch_size = input_ids.shape[0] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process) - input_ids_rmpad = input_ids_rmpad.contiguous() - output_orig = model( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=None, - packed_seq_params=packed_seq_params, - **model_kwargs, - ) - - if post_process and logits_processor is not None: - args = { - k: preprocess_packed_seqs_no_padding(v, pre_process=True, need_roll=(k == "label"))[0] - for k, v in logits_processor_args.items() - } - output_dict = logits_processor(output_orig, **args) - output = { - k: postprocess_packed_seqs_no_padding( - v, packed_seq_params, input_ids, batch_size, post_process=post_process + if data_format == "thd": + input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding(input_ids, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=None, + packed_seq_params=packed_seq_params, + **model_kwargs, + ) + + if post_process and logits_processor is not None: + args = { + k: preprocess_thd_no_padding(v, pre_process=True, need_roll=(k == "label"))[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_thd_no_padding(v, packed_seq_params, input_ids, batch_size, post_process=post_process) + for k, v in output_dict.items() + } + else: + output = postprocess_thd_no_padding( + output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process ) - for k, v in output_dict.items() - } else: - output = postprocess_packed_seqs_no_padding( - output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process + """ + data_format: "thd" or "bshd", default is "thd", + why we need this? + for some new models, GPT-OSS, the thd format is not supported, so we need to use the bshd format. + When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length, + so it is recommended to disable dynamic batch size and set batch size to 1 + """ + + input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding( + input_ids, pre_process=pre_process + ) + output_orig = model( + input_ids=input_ids_bshd, + attention_mask=attention_mask_bshd, + position_ids=position_ids_bshd, + **model_kwargs, ) + if post_process and logits_processor is not None: + args = { + k: preprocess_bshd_no_padding(v, pre_process=True, need_roll=(k == "label"))[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_bshd_no_padding(v, attention_mask_bshd, post_process=post_process) + for k, v in output_dict.items() + } + else: + output = postprocess_bshd_no_padding(output_orig, attention_mask_bshd, post_process=post_process) if value_model and post_process: # output = output[..., 0] diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index 48515ffc22d..d8c7b2cfa86 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -69,6 +69,7 @@ class SupportedModel(Enum): QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification" QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration" QWEN3_VL = "Qwen3VLForConditionalGeneration" + GPT_OSS = "GptOssForCausalLM" # Registry for model configuration converters @@ -115,6 +116,7 @@ class SupportedModel(Enum): SupportedModel.DEEPSEEK_V3: model_forward_gen(), SupportedModel.GLM4_MOE: model_forward_gen(), SupportedModel.QWEN3_TOKEN_CLASSIFICATION: model_forward_gen(), + SupportedModel.GPT_OSS: model_forward_gen(), } # Registry for model forward functions @@ -133,6 +135,7 @@ class SupportedModel(Enum): SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding, SupportedModel.GLM4_MOE: gptmodel_forward_no_padding, SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, + SupportedModel.GPT_OSS: gptmodel_forward_no_padding, } # Registry for model forward functions @@ -150,6 +153,7 @@ class SupportedModel(Enum): SupportedModel.QWEN3_MOE: fused_forward_model_gen(), SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(), SupportedModel.GLM4_MOE: fused_forward_model_gen(), + SupportedModel.GPT_OSS: fused_forward_model_gen(), } # Registry for model weight converters diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index 6ca270c6fb6..65fdbbfe2bd 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -175,7 +175,109 @@ def postprocess_packed_seqs( return output_new -def preprocess_packed_seqs_no_padding( +def preprocess_bshd( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + sequence_parallel: bool = False, + pre_process: bool = True, +): + """ + Remove left padding from input_ids, attention_mask and position_ids + return new_input_ids, new_attention_mask, new_position_ids + """ + assert attention_mask.ndim == 2 + assert position_ids.ndim == 2 + cp_size = mpu.get_context_parallel_world_size() + assert cp_size == 1, "Context parallel size without seq_pack is not supported" + batch_size = input_ids.shape[0] + shape = list(input_ids.shape) # batch_size, seq_len,... + seq_lens = attention_mask.sum(dim=1) + seq_len = seq_lens.max().item() + if sequence_parallel: + sp_world_size = mpu.get_tensor_model_parallel_world_size() + pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size + seq_len = seq_len + pad_size + shape[1] = seq_len + if pre_process: + new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) + new_attention_mask = torch.zeros( + dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + ) + new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) + for i in range(batch_size): + if pre_process: + new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] + new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] + new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] + if pre_process: + return new_input_ids, new_attention_mask, new_position_ids + else: + return input_ids, new_attention_mask, new_position_ids + + +def postprocess_bshd( + result, + attention_mask: torch.Tensor, + original_attention_mask: torch.Tensor, + origin_seqlen: int, + post_process: bool = True, +): + """ + Recover left padding from result + return result + """ + if not post_process: + return result + shape = list(result.shape) + batch_size = shape[0] + shape[1] = origin_seqlen + new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) + for i in range(batch_size): + new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] + return new_result + + +def postprocess_packed_seqs_for_dict_output( + labels_mask: torch.Tensor, + output: CausalLMOutputForPPO, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> dict[str, torch.Tensor]: + """_summary_ + For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. + This function post-processes each tensor in the output dictionary. + Args: + output (CausalLMOutputForPPO): _description_ + packed_seq_params (PackedSeqParams): _description_ + attention_mask (torch.Tensor): _description_ + batch_size (int): _description_ + seq_len (int): _description_ + post_process (bool, optional): _description_. Defaults to True. + Returns: + CausalLMOutputForPPO: _description_ + """ + ret = {} + output.entropy = output.entropy.view(1, -1) + output.log_probs = output.log_probs.view(1, -1) + output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) + ret["entropy"] = postprocess_packed_seqs( + output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + ret["log_probs"] = postprocess_packed_seqs( + output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + return ret + + +### No padding versions for model engine +### inputs are nested tensors + + +def preprocess_thd_no_padding( input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False ) -> tuple[torch.Tensor, PackedSeqParams]: """ @@ -274,7 +376,7 @@ def preprocess_packed_seqs_no_padding( return input_ids, packed_seq_params -def postprocess_packed_seqs_no_padding( +def postprocess_thd_no_padding( output: torch.Tensor, packed_seq_params: PackedSeqParams, input_ids: torch.Tensor, @@ -338,99 +440,54 @@ def postprocess_packed_seqs_no_padding( return output_new_tensor -def remove_left_padding( - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - sequence_parallel: bool = False, - pre_process: bool = True, -): +def preprocess_bshd_no_padding(input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False): """ - Remove left padding from input_ids, attention_mask and position_ids - return new_input_ids, new_attention_mask, new_position_ids + Preprocess bshd sequences + return "input_ids, attention_mask, position_ids" """ - assert attention_mask.ndim == 2 - assert position_ids.ndim == 2 cp_size = mpu.get_context_parallel_world_size() - assert cp_size == 1, "Context parallel size without seq_pack is not supported" + # TODO: support context parallel size > 1 + assert cp_size == 1, "Context parallel size without bshd is not supported yet" + batch_size = input_ids.shape[0] - shape = list(input_ids.shape) # batch_size, seq_len,... - seq_lens = attention_mask.sum(dim=1) - seq_len = seq_lens.max().item() - if sequence_parallel: + seqlens_in_batch = input_ids.offsets().diff() + max_seqlen = seqlens_in_batch.max().item() + if mpu.get_tensor_model_parallel_world_size() > 1: sp_world_size = mpu.get_tensor_model_parallel_world_size() - pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size - seq_len = seq_len + pad_size - shape[1] = seq_len - if pre_process: - new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) - new_attention_mask = torch.zeros( - dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) - ) - new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) + pad_size = (sp_world_size - max_seqlen % sp_world_size) % sp_world_size + max_seqlen = max_seqlen + pad_size + + attention_mask = torch.zeros(batch_size, max_seqlen, dtype=torch.bool, device=input_ids.device) + input_ids_bshd = torch.zeros(batch_size, max_seqlen, dtype=input_ids.dtype, device=input_ids.device) for i in range(batch_size): - if pre_process: - new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] - new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] - new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] - if pre_process: - return new_input_ids, new_attention_mask, new_position_ids - else: - return input_ids, new_attention_mask, new_position_ids + attention_mask[i, : seqlens_in_batch[i]] = True + input_ids_bshd[i, : seqlens_in_batch[i]] = input_ids[i] + position_ids = torch.arange(max_seqlen, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids_bshd) + if need_roll: + input_ids_bshd = torch.roll(input_ids_bshd, shifts=-1, dims=1) + return input_ids_bshd, attention_mask, position_ids -def recover_left_padding( - result, + +def postprocess_bshd_no_padding( + output: torch.Tensor, attention_mask: torch.Tensor, - original_attention_mask: torch.Tensor, - origin_seqlen: int, post_process: bool = True, -): +) -> torch.Tensor: """ - Recover left padding from result - return result + Postprocess bshd sequences """ if not post_process: - return result - shape = list(result.shape) - batch_size = shape[0] - shape[1] = origin_seqlen - new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) + return output + + batch_size = output.shape[0] + output_new = [] + for i in range(batch_size): - new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] - return new_result + mask = attention_mask[i].bool() + output_new.append(output[i][mask]) + output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) -def postprocess_packed_seqs_for_dict_output( - labels_mask: torch.Tensor, - output: CausalLMOutputForPPO, - packed_seq_params: PackedSeqParams, - attention_mask: torch.Tensor, - batch_size: int, - seq_len: int, - post_process: bool = True, -) -> dict[str, torch.Tensor]: - """_summary_ - For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. - This function post-processes each tensor in the output dictionary. - Args: - output (CausalLMOutputForPPO): _description_ - packed_seq_params (PackedSeqParams): _description_ - attention_mask (torch.Tensor): _description_ - batch_size (int): _description_ - seq_len (int): _description_ - post_process (bool, optional): _description_. Defaults to True. - Returns: - CausalLMOutputForPPO: _description_ - """ - ret = {} - output.entropy = output.entropy.view(1, -1) - output.log_probs = output.log_probs.view(1, -1) - output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) - ret["entropy"] = postprocess_packed_seqs( - output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - ret["log_probs"] = postprocess_packed_seqs( - output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - return ret + return output_new_tensor diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 7bf04848ea4..21d299fc8a3 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -54,6 +54,7 @@ actor_rollout_ref: override_mcore_model_config: {} use_mbridge: false vanilla_mbridge: true + use_remove_padding: true forward_only: false dtype: bfloat16 _target_: verl.workers.config.McoreActorConfig @@ -178,6 +179,7 @@ actor_rollout_ref: override_mcore_model_config: {} use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} forward_only: true dtype: bfloat16 _target_: verl.workers.config.McoreActorConfig @@ -414,6 +416,7 @@ critic: override_mcore_model_config: {} use_mbridge: false vanilla_mbridge: true + use_remove_padding: true forward_only: false dtype: bfloat16 _target_: verl.workers.config.McoreCriticConfig @@ -546,6 +549,7 @@ reward_model: override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} dtype: bfloat16 load_weight: true algorithm: diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index f8af11fb3ae..84601f5a3f5 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -80,6 +80,9 @@ use_mbridge: False # oc.select: default val for ref.megatron.vanilla_mbridge vanilla_mbridge: True +# whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length +use_remove_padding: True + # whether to use forward only forward_only: False diff --git a/verl/trainer/config/ref/megatron_ref.yaml b/verl/trainer/config/ref/megatron_ref.yaml index 944c07b698d..ca1fbb3c073 100644 --- a/verl/trainer/config/ref/megatron_ref.yaml +++ b/verl/trainer/config/ref/megatron_ref.yaml @@ -17,6 +17,7 @@ megatron: override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1} virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null} diff --git a/verl/trainer/config/reward_model/megatron_reward_model.yaml b/verl/trainer/config/reward_model/megatron_reward_model.yaml index a787bbd9bb3..ea585075e57 100644 --- a/verl/trainer/config/reward_model/megatron_reward_model.yaml +++ b/verl/trainer/config/reward_model/megatron_reward_model.yaml @@ -67,6 +67,9 @@ megatron: # Whether to use mbridge instead of Megatron-Bridge vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + # Whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + dtype: bfloat16 # Whether to load weights (default True) diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 4e9e1d62dd6..6999363501c 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -578,11 +578,12 @@ def logits_processor(logits, label, label_mask): ret = {} if calculate_entropy: logits_bak = logits.clone() - logger.warning_once( - "For memory-efficient computation, enable fused kernels via " - "`actor_rollout_ref.model.use_fused_kernels=True`. " - "The current `clone()` operation ensures correctness but increases memory usage." - ) + # # disable the hint until the fused_kernel is optimized for triton>=3.3 + # logger.warning_once( + # "For memory-efficient computation, enable fused kernels via " + # "`actor_rollout_ref.model.use_fused_kernels=True`. " + # "The current `clone()` operation ensures correctness but increases memory usage." + # ) entropy = vocab_parallel_entropy(logits) ret["entropy"] = entropy else: @@ -601,6 +602,7 @@ def logits_processor(logits, label, label_mask): multi_modal_inputs=multi_modal_inputs, logits_processor=logits_processor, logits_processor_args=logits_processor_args, + data_format="thd" if self.config.megatron.use_remove_padding else "bshd", ) if forward_only: diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index f0cebf5ad76..1399d9961b2 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -83,6 +83,7 @@ class McoreEngineConfig(EngineConfig): override_mcore_model_config: dict[str, Any] = field(default_factory=dict) use_mbridge: bool = False vanilla_mbridge: bool = True + use_remove_padding: bool = True strategy: str = "megatron" def __post_init__(self) -> None: diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 19e768a4678..deaebe3e618 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -109,7 +109,6 @@ def _build_tf_config(self): self.dtype = PrecisionType.to_dtype(self.param_dtype) override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config}) - tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config) use_mbridge = self.engine_config.use_mbridge self.provider = None @@ -163,6 +162,7 @@ def _build_tf_config(self): self.bridge = bridge else: self.bridge = None + tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config) if not self.bridge: self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype) @@ -666,12 +666,13 @@ def logits_processor(logits, label): ret = {} if calculate_entropy: logits_bak = logits.clone() - if torch.distributed.get_rank() == 0: - logger.warning_once( - "For memory-efficient computation, enable fused kernels via " - "`actor_rollout_ref.model.use_fused_kernels=True`. " - "The current `clone()` operation ensures correctness but increases memory usage." - ) + # # disable the hint until the fused_kernel is optimized for triton>=3.3 + # if torch.distributed.get_rank() == 0: + # logger.warning_once( + # "For memory-efficient computation, enable fused kernels via " + # "`actor_rollout_ref.model.use_fused_kernels=True`. " + # "The current `clone()` operation ensures correctness but increases memory usage." + # ) entropy = vocab_parallel_entropy(logits) ret["entropy"] = entropy else: @@ -689,6 +690,7 @@ def logits_processor(logits, label): multi_modal_inputs, logits_processor=logits_processor, logits_processor_args=logits_processor_args, + data_format="thd" if self.engine_config.use_remove_padding else "bshd", ) return output, partial(postprocess_micro_batch_func, data=batch)