diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index 85a37ffdf0..9d69624cb7 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit 85a37ffdf02edc07c0a7ac97cb9abcafcd0ac0ed +Subproject commit 9d69624cb75e46f06ddfadd9a726acecfcf8b064 diff --git a/3rdparty/Megatron-Bridge-workspace/setup.py b/3rdparty/Megatron-Bridge-workspace/setup.py index 06657bab31..9797c340de 100644 --- a/3rdparty/Megatron-Bridge-workspace/setup.py +++ b/3rdparty/Megatron-Bridge-workspace/setup.py @@ -33,7 +33,7 @@ "packaging", "tensorboard>=2.19.0", "torch", - "transformers>=4.51.3", + "transformers>=4.55.0", "typing-extensions", "rich", "wandb>=0.19.10", diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml index 2d39d9cd7f..a5da6ed98f 100644 --- a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml +++ b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml @@ -3,12 +3,3 @@ checkpointing: checkpoint_dir: results/clevr_grpo policy: max_total_sequence_length: 3072 -env: - refcoco: - reward_functions: - - name: format - weight: 0.1 - - name: bbox_giou - weight: 0.9 - kwargs: - giou_penalty_thres: 1.0 diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml new file mode 100644 index 0000000000..c8657ef818 --- /dev/null +++ b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml @@ -0,0 +1,25 @@ +defaults: ../../vlm_grpo_3B.yaml +checkpointing: + checkpoint_dir: results/clevr_grpo +policy: + max_total_sequence_length: 3072 + dtensor_cfg: + enabled: false + dynamic_batching: + enabled: false + make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} + optimizer: null + megatron_cfg: + enabled: true + empty_unused_memory_level: 1 + optimizer: + lr: 5.0e-07 + min_lr: 5.0e-08 + scheduler: + lr_warmup_iters: 50 + lr_warmup_init: 5.0e-08 + distributed_data_parallel_config: + overlap_grad_reduce: false +logger: + wandb: + name: vlm-grpo-3b-megatron diff --git a/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml.disabled similarity index 100% rename from examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml rename to examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml.disabled diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 3c61241714..460bc3474d 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -58,6 +58,70 @@ policy: context_parallel_size: 1 custom_parallel_plan: null + megatron_cfg: + enabled: false + empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: "Qwen2ForCausalLM" + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + 
pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + defer_fp32_logits: null + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length # responses are sorted by sequence length and bucketed into microbatches with a total @@ -76,6 +140,10 @@ policy: sequence_packing: enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 optimizer: name: "torch.optim.AdamW" diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml new file mode 100644 index 0000000000..ca7b03301e --- /dev/null +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -0,0 +1,200 @@ +grpo: + num_prompts_per_step: 8 + num_generations_per_prompt: 16 + max_rollout_turns: 1 + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + token_level_loss: true +checkpointing: + enabled: true + checkpoint_dir: results/clevr_grpo_${policy.model_name} + metric_name: val_reward + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null +policy: + model_name: Qwen/Qwen2.5-VL-3B-Instruct + tokenizer: + name: ${policy.model_name} + train_global_batch_size: 128 + train_micro_batch_size: 1 + generation_batch_size: 32 + logprob_batch_size: 4 + max_total_sequence_length: 2048 + precision: bfloat16 + dtensor_cfg: + _v2: true + enabled: false + cpu_offload: false + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: + enabled: false + train_mb_tokens: 
${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} + max_grad_norm: 1.0 + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + optimizer: null + scheduler: + - name: torch.optim.lr_scheduler.LinearLR + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: torch.optim.lr_scheduler.ConstantLR + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: + - 50 + generation: + backend: vllm + max_new_tokens: 1024 + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: false + enable_expert_parallel: false + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + megatron_cfg: + enabled: true + empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: Qwen2ForCausalLM + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + moe_permute_fusion: false + apply_rope_fusion: true + optimizer: + optimizer: adam + lr: 2.0e-07 + min_lr: 2.0e-07 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: float32 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1.0e-08 + sgd_momentum: 0.9 + use_distributed_optimizer: true + use_precision_aware_optimizer: true + clip_grad: ${policy.max_grad_norm} + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: constant + lr_decay_style: constant + lr_decay_iters: 1000 + lr_warmup_iters: 50 + lr_warmup_init: 2.0e-08 + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: false + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: optim_grads_params +data: + max_input_seq_length: ${policy.max_total_sequence_length} + prompt_file: examples/prompts/clevr_cogent_cot.txt + system_prompt_file: null + dataset_name: clevr-cogent + split: trainA + shuffle: true +env: + clevr-cogent: + num_workers: 8 + reward_functions: + - name: format + weight: 0.2 + - name: exact_alnum + weight: 0.8 + geometry3k: + num_workers: 8 + reward_functions: + - name: format + weight: 0.1 + - name: math_expr + weight: 0.9 + refcoco: + num_workers: 8 + reward_functions: + - name: format + weight: 0.1 + - name: bbox_giou + weight: 0.9 + kwargs: + giou_penalty_thres: 0.5 +logger: + log_dir: logs + num_val_samples_to_print: 0 + wandb_enabled: false + tensorboard_enabled: true + swanlab_enabled: 
false + mlflow_enabled: false + monitor_gpus: false + wandb: + project: grpo-dev + name: vlm-grpo-3b-megatron + tensorboard: {} + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 +cluster: + gpus_per_node: 2 + num_nodes: 1 diff --git a/examples/run_vlm_grpo.py b/examples/run_vlm_grpo.py index ef69d42528..4ac0922d5b 100644 --- a/examples/run_vlm_grpo.py +++ b/examples/run_vlm_grpo.py @@ -194,16 +194,29 @@ def hf_data_processor( length = sum(len(m["token_ids"]) for m in message_log) loss_multiplier = 1.0 - if length > max_seq_length: + if length >= max_seq_length: + # Treat truncated messages as text only + vllm_kwargs = { + "vllm_content": None, + "vllm_images": [], + } + # make smaller and mask out for chat_message in message_log: chat_message["token_ids"] = chat_message["token_ids"][ : min(4, max_seq_length // len(message_log)) ] + for key, value in chat_message.items(): + if isinstance(value, PackedTensor): + chat_message[key] = PackedTensor.empty_like(value) loss_multiplier = 0.0 - raise NotImplementedError( - "Sequence length is too long, please use a shorter sequence length" - ) + else: + # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation + # add images for vllm serving + vllm_kwargs = { + "vllm_content": string_formatted_dialog, + "vllm_images": images, + } output: DatumSpec = { "message_log": message_log, @@ -212,10 +225,7 @@ def hf_data_processor( "loss_multiplier": loss_multiplier, "idx": idx, "task_name": task_data_spec.task_name, - # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation - # add images for vllm serving - "vllm_content": string_formatted_dialog, - "vllm_images": images, + **vllm_kwargs, } return output diff --git a/nemo_rl/data/multimodal_utils.py b/nemo_rl/data/multimodal_utils.py index 74e5a73a8c..0da507acc7 100644 --- a/nemo_rl/data/multimodal_utils.py +++ b/nemo_rl/data/multimodal_utils.py @@ -30,34 +30,47 @@ class PackedTensor: """ def __init__( - self, tensors: Union[torch.Tensor, list[torch.Tensor]], dim_to_pack: int + self, + tensors: Union[torch.Tensor, list[Optional[torch.Tensor]], list[None]], + dim_to_pack: int, ) -> None: assert tensors is not None, "Input tensors to PackedTensor cannot be None" if isinstance(tensors, torch.Tensor): - self.tensors: list[torch.Tensor] = [tensors] + self.tensors: list[Optional[torch.Tensor]] = [tensors] elif isinstance(tensors, list): assert len(tensors) > 0, ( "Input tensors to PackedTensor must be a non-empty list" ) - self.tensors: list[torch.Tensor] = tensors + self.tensors: list[Optional[torch.Tensor]] = tensors else: raise ValueError( f"Unsupported type for input tensors to PackedTensor: {type(tensors)}" ) self.dim_to_pack = dim_to_pack - def as_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor: + def as_tensor( + self, device: Optional[torch.device] = None + ) -> Optional[torch.Tensor]: if device is not None: - self.tensors = [item.to(device) for item in self.tensors] - return torch.cat(self.tensors, dim=self.dim_to_pack).to(device) + # Move only non-None tensors to device, preserve Nones + for i, item in enumerate(self.tensors): + if item is not None: + self.tensors[i] = item.to(device) + non_none_tensors = [t for t in self.tensors if t is not None] + if len(non_none_tensors) == 0: + return None + else: + return torch.cat(non_none_tensors, dim=self.dim_to_pack).to(device) def __len__(self) -> int: # this is the number of tensors in 
this data wrapper return len(self.tensors) def to(self, device: str | torch.device) -> "PackedTensor": - self.tensors = [item.to(device) for item in self.tensors] + self.tensors = [ + item.to(device) if item is not None else None for item in self.tensors + ] return self def slice(self, indices: Union[list[int], torch.Tensor]) -> "PackedTensor": @@ -65,6 +78,11 @@ def slice(self, indices: Union[list[int], torch.Tensor]) -> "PackedTensor": tensors = [self.tensors[i] for i in idx] return PackedTensor(tensors, self.dim_to_pack) + @classmethod + def empty_like(cls, other: "PackedTensor") -> "PackedTensor": + """Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None.""" + return cls([None] * len(other.tensors), other.dim_to_pack) + @classmethod def concat(cls, from_packed_tensors: list["PackedTensor"]) -> "PackedTensor": """Concatenate a list of PackedTensor objects into a single PackedTensor. diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index 4d4187a46b..5c0cd81003 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -126,7 +126,7 @@ def from_batches( item for sublist in list_of_tensors for item in sublist ] elif isinstance(list_of_tensors[0], PackedTensor): - tensor_or_list = PackedTensor.flattened_concat(list_of_tensors) + tensor_or_list = PackedTensor.concat(list_of_tensors) elif all(x.ndim == 1 for x in list_of_tensors): tensor_or_list = torch.cat(list_of_tensors) elif isinstance(list_of_tensors[0], torch.Tensor): diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py index 0243464e56..d4a8cd88ef 100644 --- a/nemo_rl/models/generation/vllm/utils.py +++ b/nemo_rl/models/generation/vllm/utils.py @@ -67,7 +67,10 @@ def _get_regular_prompt(index: int): prompt_dict = {"prompt": msg} # add additional data if present images = data.get("vllm_images", None) - if images is not None: + if images is None or len(images[i]) == 0: + prompts.append(_get_regular_prompt(i)) + continue + else: prompt_dict["multi_modal_data"] = { "image": images[i][0] if len(images[i]) == 1 else images[i] } diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 38078dca13..87a0ddb1d5 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -342,12 +342,19 @@ def forward_step_arbitrary_loss( pad_mask_loss=False, ) + multimodal_data = data_dict.get_multimodal_dict( + as_tensors=True, device=input_ids_cp_sharded.device + ) + if len(multimodal_data) > 0: + position_ids = None + with straggler_timer: output_tensor = model( - input_ids_cp_sharded, - position_ids, - attention_mask, + input_ids=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, packed_seq_params=packed_seq_params, + **multimodal_data, ) # Apply temperature scaling to logits for training diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 1f87676060..5292128451 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -252,6 +252,9 @@ def freeze_moe_router(megatron_model): # Handle both wrapped (Float16Module) and unwrapped models if isinstance(model_module, Float16Module): model_module = model_module.module + # Handle VLM models + if hasattr(model_module, "language_model"): + model_module = model_module.language_model for layer in model_module.decoder.layers: if 
hasattr(layer.mlp, "router"): layer.mlp.router.weight.requires_grad = False @@ -265,6 +268,9 @@ def re_enable_float32_expert_bias(megatron_model): # Handle both wrapped (Float16Module) and unwrapped models if isinstance(model_module, Float16Module): model_module = model_module.module + # Handle VLM models + if hasattr(model_module, "language_model"): + model_module = model_module.language_model for layer in model_module.decoder.layers: if hasattr(layer.mlp, "router"): layer.mlp.router._maintain_float32_expert_bias() @@ -1199,11 +1205,18 @@ def forward_step_fn( packed_seq_params = None unpacked_input_ids = input_ids + multimodal_data = data_dict.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + if len(multimodal_data) > 0: + position_ids = None + output_tensor = model( - input_ids_cp_sharded, - position_ids, - attention_mask, + input_ids=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, packed_seq_params=packed_seq_params, + **multimodal_data, ) # Apply temperature scaling to logits for training diff --git a/pyproject.toml b/pyproject.toml index 69d3d9fea4..36e24a6365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,9 +25,7 @@ dependencies = [ "triton", "colored==2.2.3", "ray[default]==2.46.0", - # transformers==4.54.0/4.54.1 both fail on rm models - # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved - "transformers>=4.51.0,<4.54.0", + "transformers>=4.55.4", "wandb", "numpy", "datasets>=4.0.0", diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 76e4e55429..9e7f8ff3be 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -21,7 +21,10 @@ tests/test_suites/llm/grpo-moonlight-16ba3b-4n8g-megatron.sh # Functional VLM run tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.sh -tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh +tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh + +# Removing this until this issue is resolved: https://github.com/huggingface/transformers/issues/41190 +# tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh # Deepscaler (short tests) tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh diff --git a/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh b/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh new file mode 100755 index 0000000000..b3c6764f65 --- /dev/null +++ b/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=200 +MAX_STEPS=200 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_vlm_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only 
run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["200"] < 0.1' \ + 'data["train/reward"]["200"] > 0.9' +fi + diff --git a/tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh b/tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh.disabled similarity index 100% rename from tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh rename to tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh.disabled diff --git a/tests/unit/data/test_multimodal_dict.py b/tests/unit/data/test_multimodal_dict.py index ff95534e83..a94412222a 100644 --- a/tests/unit/data/test_multimodal_dict.py +++ b/tests/unit/data/test_multimodal_dict.py @@ -316,3 +316,36 @@ def test_get_multimodal_dict(): assert "token_type_ids" in mm_dict assert isinstance(mm_dict["image_features"], PackedTensor) assert torch.is_tensor(mm_dict["token_type_ids"]) + + +def test_packedtensor_all_none(): + pt = PackedTensor([None, None], dim_to_pack=0) + assert pt.as_tensor() is None + + +def test_packedtensor_with_none_entry(): + original = PackedTensor([torch.randn(2, 3), None], dim_to_pack=0) + empty = PackedTensor.empty_like(original) + # same logical length + assert len(empty) == len(original) + # all entries are None, thus as_tensor returns None + assert empty.as_tensor() is None + + +def test_packedtensor_to_with_none_entry(): + t = torch.randn(1, 2) + pt = PackedTensor([None, t], dim_to_pack=0) + pt = pt.to("cpu") + assert pt.tensors[0] is None + assert isinstance(pt.tensors[1], torch.Tensor) + assert pt.tensors[1].device.type == "cpu" + + +def test_packedtensor_as_tensor_with_mixed_none_and_tensors(): + t1 = torch.randn(2, 3) + t2 = None + t3 = torch.randn(4, 3) + pt = PackedTensor([t1, t2, t3], dim_to_pack=0) + out = pt.as_tensor() + expected = torch.cat([t1, t3], dim=0) + assert torch.equal(out, expected) diff --git a/tests/unit/models/generation/test_vllm_utils.py b/tests/unit/models/generation/test_vllm_utils.py new file mode 100644 index 0000000000..4093b4c5ae --- /dev/null +++ b/tests/unit/models/generation/test_vllm_utils.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.models.generation.vllm.utils import ( + format_prompt_for_vllm_generation, +) + + +def _mk_inputs(batch_size: int = 2, seq_len: int = 5): + input_ids = torch.arange(batch_size * seq_len).view(batch_size, seq_len) + # make second example shorter + input_lengths = torch.tensor([seq_len, seq_len - 2]) + return input_ids, input_lengths + + +def test_vllm_utils_regular_llm_path(): + input_ids, input_lengths = _mk_inputs() + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + } + ) + prompts = format_prompt_for_vllm_generation(data) + assert isinstance(prompts, list) and len(prompts) == 2 + # first has full length + assert prompts[0]["prompt_token_ids"] == input_ids[0].tolist() + # second trimmed by input_lengths + assert prompts[1]["prompt_token_ids"] == input_ids[1, : input_lengths[1]].tolist() + + +def test_vllm_utils_vlm_with_images_and_text(): + # Batch with two samples + # both have content; first has one image, second has two images + input_ids, input_lengths = _mk_inputs() + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "vllm_content": ["user: hi", "user: hello"], + "vllm_images": [["img1"], ["img2a", "img2b"]], + } + ) + + prompts = format_prompt_for_vllm_generation(data) + assert len(prompts) == 2 + assert prompts[0]["prompt"] == "user: hi" + assert prompts[0]["multi_modal_data"]["image"] == "img1" + assert prompts[1]["prompt"] == "user: hello" + assert prompts[1]["multi_modal_data"]["image"] == ["img2a", "img2b"] + + +def test_vllm_utils_vlm_with_missing_images_fallback_to_tokens(): + input_ids, input_lengths = _mk_inputs() + # images None triggers fallback + data_none = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "vllm_content": ["a", "b"], + "vllm_images": None, + } + ) + prompts = format_prompt_for_vllm_generation(data_none) + assert all("prompt_token_ids" in p for p in prompts) + + # images empty per sample also triggers fallback + data_empty = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "vllm_content": ["a", "b"], + "vllm_images": [[], []], + } + ) + prompts = format_prompt_for_vllm_generation(data_empty) + assert all("prompt_token_ids" in p for p in prompts) + + +def test_vllm_utils_vlm_with_none_content_fallback_to_tokens_and_sample_idx(): + input_ids, input_lengths = _mk_inputs() + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "vllm_content": [None, None], + "vllm_images": [["img"], ["img"]], + } + ) + # even though images provided, None content should fallback to tokens + prompts_all = format_prompt_for_vllm_generation(data) + assert len(prompts_all) == 2 + assert all("prompt_token_ids" in p for p in prompts_all) + + # single-sample API + p0 = format_prompt_for_vllm_generation(data, sample_idx=0) + p1 = format_prompt_for_vllm_generation(data, sample_idx=1) + assert isinstance(p0, dict) and isinstance(p1, dict) + assert "prompt_token_ids" in p0 and "prompt_token_ids" in p1 diff --git a/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py b/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py new file mode 100644 index 0000000000..c1e4e927b9 --- /dev/null +++ b/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from transformers import AutoModelForImageTextToText, AutoProcessor + + +class SmolVLMVisionEmbeddingsReference(nn.Module): + """ + Previous (correct) implementation in transformers<=4.54.1. Copied from https://github.com/huggingface/transformers/blob/4.54.1/src/transformers/models/smolvlm/modeling_smolvlm.py#L101-L156 + + Remove this test once upstream bug is fixed. + """ + + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +def test_smolvlm_embeddings_differ_from_reference(): + # Remove once https://github.com/huggingface/transformers/issues/41190 is fixed and adopted. 
+ + device = "cuda" if torch.cuda.is_available() else "cpu" + + model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" + processor = AutoProcessor.from_pretrained(model_path) + model = AutoModelForImageTextToText.from_pretrained( + model_path, torch_dtype=torch.bfloat16 + ) + model = model.to(device) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Can you describe this image?"}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items() + } + inputs = { + k: v.to(dtype=torch.bfloat16) + if isinstance(v, torch.Tensor) and v.is_floating_point() + else v + for k, v in inputs.items() + } + + patch_size = model.model.vision_model.patch_size + pixel_values = inputs["pixel_values"] # (bsz, num_images, 3, H, W) + bsz, num_images, _, H, W = pixel_values.shape + pixel_values = pixel_values.view(bsz * num_images, *pixel_values.shape[2:]) + + patch_attention_mask = torch.ones( + ( + bsz, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ), + device=pixel_values.device, + dtype=torch.bool, + ) + + # Get buggy/current embeddings module from installed transformers + embeddings_buggy = model.model.vision_model.embeddings + + with torch.no_grad(): + out_buggy = embeddings_buggy( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + # Build reference embeddings and copy weights for apples-to-apples comparison + ref = SmolVLMVisionEmbeddingsReference(model.model.vision_model.config) + ref = ref.to(device=device, dtype=torch.bfloat16) + + # Copy the conv and embedding weights + ref.patch_embedding.load_state_dict(embeddings_buggy.patch_embedding.state_dict()) + ref.position_embedding.load_state_dict( + embeddings_buggy.position_embedding.state_dict() + ) + + with torch.no_grad(): + out_ref = ref( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + # Assert outputs differ due to the upstream bug + are_equal = torch.allclose(out_buggy.float(), out_ref.float(), atol=0, rtol=0) + assert not are_equal, ( + "If this fails, that means the upstream bug has been fixed. 
You can close this issue: https://github.com/huggingface/transformers/issues/41190" + ) diff --git a/uv.lock b/uv.lock index 50e2abc88f..f93fc1053e 100644 --- a/uv.lock +++ b/uv.lock @@ -2335,7 +2335,7 @@ requires-dist = [ { name = "torch", marker = "sys_platform == 'darwin'", index = "https://pypi.org/simple" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "transformer-engine", extras = ["pytorch"], marker = "sys_platform != 'darwin'", specifier = ">=2.5.0a0,<2.6.0" }, - { name = "transformers", specifier = ">=4.51.3" }, + { name = "transformers", specifier = ">=4.55.0" }, { name = "typing-extensions" }, { name = "wandb", specifier = ">=0.19.10" }, ] @@ -3041,7 +3041,7 @@ requires-dist = [ { name = "torchvision", marker = "sys_platform != 'darwin'", specifier = ">=0.22.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.22.0", index = "https://pypi.org/simple" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'mcore'", specifier = "==2.5.0" }, - { name = "transformers", specifier = ">=4.51.0,<4.54.0" }, + { name = "transformers", specifier = ">=4.55.4" }, { name = "triton", marker = "sys_platform != 'darwin'", index = "https://download.pytorch.org/whl/cu128" }, { name = "triton", marker = "sys_platform == 'darwin'", index = "https://pypi.org/simple" }, { name = "vllm", marker = "extra == 'automodel'", specifier = "==0.10.0" }, @@ -6052,7 +6052,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a1/1d/73ec467d20d96e0bb [[package]] name = "transformers" -version = "4.53.3" +version = "4.55.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -6066,9 +6066,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/5c/49182918b58eaa0b4c954fd0e37c79fc299e5643e69d70089d0b0eb0cd9b/transformers-4.53.3.tar.gz", hash = "sha256:b2eda1a261de79b78b97f7888fe2005fc0c3fabf5dad33d52cc02983f9f675d8", size = 9197478, upload-time = "2025-07-22T07:30:51.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/43/3cb831d5f28cc723516e5bb43a8c6042aca3038bb36b6bd6016b40dfd1e8/transformers-4.55.4.tar.gz", hash = "sha256:574a30559bc273c7a4585599ff28ab6b676e96dc56ffd2025ecfce2fd0ab915d", size = 9573015, upload-time = "2025-08-22T15:18:43.192Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382, upload-time = "2025-07-22T07:30:48.458Z" }, + { url = "https://files.pythonhosted.org/packages/fa/0a/8791a6ee0529c45f669566969e99b75e2ab20eb0bfee8794ce295c18bdad/transformers-4.55.4-py3-none-any.whl", hash = "sha256:df28f3849665faba4af5106f0db4510323277c4bb595055340544f7e59d06458", size = 11269659, upload-time = "2025-08-22T15:18:40.025Z" }, ] [[package]]
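
A minimal usage sketch of the None-tolerant PackedTensor introduced in nemo_rl/data/multimodal_utils.py (shapes and values here are illustrative only; it assumes nemo_rl is importable). It condenses the behavior the new unit tests in tests/unit/data/test_multimodal_dict.py exercise: empty_like() preserves the logical length, None entries survive .to() device moves, and as_tensor() skips None slots or returns None when nothing is left to pack.

import torch

from nemo_rl.data.multimodal_utils import PackedTensor

# Two image-feature tensors packed along dim 0.
features = PackedTensor([torch.randn(2, 8), torch.randn(3, 8)], dim_to_pack=0)
assert features.as_tensor().shape == (5, 8)

# A truncated, text-only sample keeps the same logical length but carries no data;
# as_tensor() then returns None instead of raising.
blank = PackedTensor.empty_like(features)
assert len(blank) == len(features) and blank.as_tensor() is None

# Mixed entries are allowed: None slots are preserved by .to() and skipped when packing.
mixed = PackedTensor([torch.randn(2, 8), None], dim_to_pack=0).to("cpu")
assert mixed.as_tensor().shape == (2, 8)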