diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
index 85a37ffdf0..9d69624cb7 160000
--- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
+++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
@@ -1 +1 @@
-Subproject commit 85a37ffdf02edc07c0a7ac97cb9abcafcd0ac0ed
+Subproject commit 9d69624cb75e46f06ddfadd9a726acecfcf8b064
diff --git a/3rdparty/Megatron-Bridge-workspace/setup.py b/3rdparty/Megatron-Bridge-workspace/setup.py
index 06657bab31..9797c340de 100644
--- a/3rdparty/Megatron-Bridge-workspace/setup.py
+++ b/3rdparty/Megatron-Bridge-workspace/setup.py
@@ -33,7 +33,7 @@
"packaging",
"tensorboard>=2.19.0",
"torch",
- "transformers>=4.51.3",
+ "transformers>=4.55.0",
"typing-extensions",
"rich",
"wandb>=0.19.10",
diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml
index 2d39d9cd7f..a5da6ed98f 100644
--- a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml
+++ b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml
@@ -3,12 +3,3 @@ checkpointing:
checkpoint_dir: results/clevr_grpo
policy:
max_total_sequence_length: 3072
-env:
- refcoco:
- reward_functions:
- - name: format
- weight: 0.1
- - name: bbox_giou
- weight: 0.9
- kwargs:
- giou_penalty_thres: 1.0
diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml
new file mode 100644
index 0000000000..c8657ef818
--- /dev/null
+++ b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.yaml
@@ -0,0 +1,25 @@
+defaults: ../../vlm_grpo_3B.yaml
+checkpointing:
+ checkpoint_dir: results/clevr_grpo
+policy:
+ max_total_sequence_length: 3072
+ dtensor_cfg:
+ enabled: false
+ dynamic_batching:
+ enabled: false
+ make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
+ optimizer: null
+ megatron_cfg:
+ enabled: true
+ empty_unused_memory_level: 1
+ optimizer:
+ lr: 5.0e-07
+ min_lr: 5.0e-08
+ scheduler:
+ lr_warmup_iters: 50
+ lr_warmup_init: 5.0e-08
+ distributed_data_parallel_config:
+ overlap_grad_reduce: false
+logger:
+ wandb:
+ name: vlm-grpo-3b-megatron
diff --git a/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml.disabled
similarity index 100%
rename from examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml
rename to examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml.disabled
diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml
index 3c61241714..460bc3474d 100644
--- a/examples/configs/vlm_grpo_3B.yaml
+++ b/examples/configs/vlm_grpo_3B.yaml
@@ -58,6 +58,70 @@ policy:
context_parallel_size: 1
custom_parallel_plan: null
+ megatron_cfg:
+ enabled: false
+ empty_unused_memory_level: 0
+ activation_checkpointing: false
+ converter_type: "Qwen2ForCausalLM"
+ tensor_model_parallel_size: 1
+ expert_tensor_parallel_size: 1
+ expert_model_parallel_size: 1
+ pipeline_model_parallel_size: 1
+ num_layers_in_first_pipeline_stage: null
+ num_layers_in_last_pipeline_stage: null
+ context_parallel_size: 1
+ pipeline_dtype: ${policy.precision}
+ sequence_parallel: false
+ freeze_moe_router: true
+ moe_router_dtype: "fp64"
+ moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
+ moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
+ moe_permute_fusion: false
+ # gives ~20% training perf speedup with sequence packing
+ apply_rope_fusion: True
+ defer_fp32_logits: null
+
+ optimizer:
+ optimizer: "adam"
+ lr: 5.0e-6
+ min_lr: 5.0e-7
+ weight_decay: 0.01
+ bf16: true
+ fp16: false
+ params_dtype: "float32"
+
+ #adam
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_eps: 1e-8
+
+ #sgd
+ sgd_momentum: 0.9
+
+ #distributed optimizer
+ use_distributed_optimizer: true
+ use_precision_aware_optimizer: true
+
+ clip_grad: ${policy.max_grad_norm}
+
+ scheduler:
+ start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ weight_decay_incr_style: "constant"
+ lr_decay_style: "constant"
+ lr_decay_iters: 1000
+ lr_warmup_iters: 13
+ lr_warmup_init: 5.0e-7
+
+ distributed_data_parallel_config:
+ grad_reduce_in_fp32: false
+ overlap_grad_reduce: true
+ overlap_param_gather: true
+ average_in_collective: true
+ use_custom_fsdp: false
+ data_parallel_sharding_strategy: "optim_grads_params"
+
+
# dynamic_batching improves performance by ensuring logprob and training microbatches
# have a sufficient number of tokens to maximize GPU utilization. Specifically, variable length
# responses are sorted by sequence length and bucketed into microbatches with a total
@@ -76,6 +140,10 @@ policy:
sequence_packing:
enabled: False
+ train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+ logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+ algorithm: "modified_first_fit_decreasing"
+ sequence_length_round: 64
optimizer:
name: "torch.optim.AdamW"
diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml
new file mode 100644
index 0000000000..ca7b03301e
--- /dev/null
+++ b/examples/configs/vlm_grpo_3B_megatron.yaml
@@ -0,0 +1,200 @@
+grpo:
+ num_prompts_per_step: 8
+ num_generations_per_prompt: 16
+ max_rollout_turns: 1
+ max_num_epochs: 1
+ max_num_steps: 1000000
+ normalize_rewards: true
+ use_leave_one_out_baseline: true
+ val_period: 10
+ val_at_start: false
+ overlong_filtering: false
+ max_val_samples: 256
+ val_batch_size: 256
+ seed: 42
+ async_grpo:
+ enabled: false
+ max_trajectory_age_steps: 1
+loss_fn:
+ reference_policy_kl_penalty: 0.01
+ ratio_clip_min: 0.2
+ ratio_clip_max: 0.2
+ ratio_clip_c: null
+ use_on_policy_kl_approximation: false
+ use_importance_sampling_correction: false
+ token_level_loss: true
+checkpointing:
+ enabled: true
+ checkpoint_dir: results/clevr_grpo_${policy.model_name}
+ metric_name: val_reward
+ higher_is_better: true
+ keep_top_k: 3
+ save_period: 10
+ checkpoint_must_save_by: null
+policy:
+ model_name: Qwen/Qwen2.5-VL-3B-Instruct
+ tokenizer:
+ name: ${policy.model_name}
+ train_global_batch_size: 128
+ train_micro_batch_size: 1
+ generation_batch_size: 32
+ logprob_batch_size: 4
+ max_total_sequence_length: 2048
+ precision: bfloat16
+ dtensor_cfg:
+ _v2: true
+ enabled: false
+ cpu_offload: false
+ sequence_parallel: false
+ activation_checkpointing: false
+ tensor_parallel_size: 1
+ context_parallel_size: 1
+ custom_parallel_plan: null
+ dynamic_batching:
+ enabled: false
+ train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+ logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+ sequence_length_round: 64
+ make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
+ max_grad_norm: 1.0
+ sequence_packing:
+ enabled: false
+ train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+ logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+ algorithm: modified_first_fit_decreasing
+ sequence_length_round: 64
+ optimizer: null
+ scheduler:
+ - name: torch.optim.lr_scheduler.LinearLR
+ kwargs:
+ start_factor: 0.1
+ end_factor: 1.0
+ total_iters: 50
+ - name: torch.optim.lr_scheduler.ConstantLR
+ kwargs:
+ factor: 1.0
+ total_iters: 10000000000
+ - milestones:
+ - 50
+ generation:
+ backend: vllm
+ max_new_tokens: 1024
+ temperature: 1.0
+ top_p: 1.0
+ top_k: null
+ stop_token_ids: null
+ stop_strings: null
+ vllm_cfg:
+ async_engine: false
+ precision: ${policy.precision}
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ expert_parallel_size: 1
+ gpu_memory_utilization: 0.6
+ max_model_len: ${policy.max_total_sequence_length}
+ enforce_eager: false
+ enable_expert_parallel: false
+ colocated:
+ enabled: true
+ resources:
+ gpus_per_node: null
+ num_nodes: null
+ megatron_cfg:
+ enabled: true
+ empty_unused_memory_level: 0
+ activation_checkpointing: false
+ converter_type: Qwen2ForCausalLM
+ tensor_model_parallel_size: 1
+ expert_tensor_parallel_size: 1
+ expert_model_parallel_size: 1
+ pipeline_model_parallel_size: 1
+ num_layers_in_first_pipeline_stage: null
+ num_layers_in_last_pipeline_stage: null
+ context_parallel_size: 1
+ pipeline_dtype: ${policy.precision}
+ sequence_parallel: false
+ freeze_moe_router: true
+ moe_router_dtype: fp64
+ moe_router_load_balancing_type: none
+ moe_router_bias_update_rate: 0.0
+ moe_permute_fusion: false
+ apply_rope_fusion: true
+ optimizer:
+ optimizer: adam
+ lr: 2.0e-07
+ min_lr: 2.0e-07
+ weight_decay: 0.01
+ bf16: true
+ fp16: false
+ params_dtype: float32
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_eps: 1.0e-08
+ sgd_momentum: 0.9
+ use_distributed_optimizer: true
+ use_precision_aware_optimizer: true
+ clip_grad: ${policy.max_grad_norm}
+ scheduler:
+ start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ weight_decay_incr_style: constant
+ lr_decay_style: constant
+ lr_decay_iters: 1000
+ lr_warmup_iters: 50
+ lr_warmup_init: 2.0e-08
+ distributed_data_parallel_config:
+ grad_reduce_in_fp32: false
+ overlap_grad_reduce: false
+ overlap_param_gather: true
+ average_in_collective: true
+ use_custom_fsdp: false
+ data_parallel_sharding_strategy: optim_grads_params
+data:
+ max_input_seq_length: ${policy.max_total_sequence_length}
+ prompt_file: examples/prompts/clevr_cogent_cot.txt
+ system_prompt_file: null
+ dataset_name: clevr-cogent
+ split: trainA
+ shuffle: true
+env:
+ clevr-cogent:
+ num_workers: 8
+ reward_functions:
+ - name: format
+ weight: 0.2
+ - name: exact_alnum
+ weight: 0.8
+ geometry3k:
+ num_workers: 8
+ reward_functions:
+ - name: format
+ weight: 0.1
+ - name: math_expr
+ weight: 0.9
+ refcoco:
+ num_workers: 8
+ reward_functions:
+ - name: format
+ weight: 0.1
+ - name: bbox_giou
+ weight: 0.9
+ kwargs:
+ giou_penalty_thres: 0.5
+logger:
+ log_dir: logs
+ num_val_samples_to_print: 0
+ wandb_enabled: false
+ tensorboard_enabled: true
+ swanlab_enabled: false
+ mlflow_enabled: false
+ monitor_gpus: false
+ wandb:
+ project: grpo-dev
+ name: vlm-grpo-3b-megatron
+ tensorboard: {}
+ gpu_monitoring:
+ collection_interval: 10
+ flush_interval: 10
+cluster:
+ gpus_per_node: 2
+ num_nodes: 1
diff --git a/examples/run_vlm_grpo.py b/examples/run_vlm_grpo.py
index ef69d42528..4ac0922d5b 100644
--- a/examples/run_vlm_grpo.py
+++ b/examples/run_vlm_grpo.py
@@ -194,16 +194,29 @@ def hf_data_processor(
length = sum(len(m["token_ids"]) for m in message_log)
loss_multiplier = 1.0
- if length > max_seq_length:
+ if length >= max_seq_length:
+ # Treat truncated messages as text only
+ vllm_kwargs = {
+ "vllm_content": None,
+ "vllm_images": [],
+ }
+
# make smaller and mask out
for chat_message in message_log:
chat_message["token_ids"] = chat_message["token_ids"][
: min(4, max_seq_length // len(message_log))
]
+ for key, value in chat_message.items():
+ if isinstance(value, PackedTensor):
+ chat_message[key] = PackedTensor.empty_like(value)
loss_multiplier = 0.0
- raise NotImplementedError(
- "Sequence length is too long, please use a shorter sequence length"
- )
+ else:
+ # Keep the prompt content for the entire conversation (the vLLM backend needs the formatted dialog and the list of images),
+ # and pass the images along for vLLM serving.
+ vllm_kwargs = {
+ "vllm_content": string_formatted_dialog,
+ "vllm_images": images,
+ }
output: DatumSpec = {
"message_log": message_log,
@@ -212,10 +225,7 @@ def hf_data_processor(
"loss_multiplier": loss_multiplier,
"idx": idx,
"task_name": task_data_spec.task_name,
- # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
- # add images for vllm serving
- "vllm_content": string_formatted_dialog,
- "vllm_images": images,
+ **vllm_kwargs,
}
return output
diff --git a/nemo_rl/data/multimodal_utils.py b/nemo_rl/data/multimodal_utils.py
index 74e5a73a8c..0da507acc7 100644
--- a/nemo_rl/data/multimodal_utils.py
+++ b/nemo_rl/data/multimodal_utils.py
@@ -30,34 +30,47 @@ class PackedTensor:
"""
def __init__(
- self, tensors: Union[torch.Tensor, list[torch.Tensor]], dim_to_pack: int
+ self,
+ tensors: Union[torch.Tensor, list[Optional[torch.Tensor]], list[None]],
+ dim_to_pack: int,
) -> None:
assert tensors is not None, "Input tensors to PackedTensor cannot be None"
if isinstance(tensors, torch.Tensor):
- self.tensors: list[torch.Tensor] = [tensors]
+ self.tensors: list[Optional[torch.Tensor]] = [tensors]
elif isinstance(tensors, list):
assert len(tensors) > 0, (
"Input tensors to PackedTensor must be a non-empty list"
)
- self.tensors: list[torch.Tensor] = tensors
+ self.tensors: list[Optional[torch.Tensor]] = tensors
else:
raise ValueError(
f"Unsupported type for input tensors to PackedTensor: {type(tensors)}"
)
self.dim_to_pack = dim_to_pack
- def as_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
+ def as_tensor(
+ self, device: Optional[torch.device] = None
+ ) -> Optional[torch.Tensor]:
if device is not None:
- self.tensors = [item.to(device) for item in self.tensors]
- return torch.cat(self.tensors, dim=self.dim_to_pack).to(device)
+ # Move only non-None tensors to device, preserve Nones
+ for i, item in enumerate(self.tensors):
+ if item is not None:
+ self.tensors[i] = item.to(device)
+ non_none_tensors = [t for t in self.tensors if t is not None]
+ if len(non_none_tensors) == 0:
+ return None
+ else:
+ return torch.cat(non_none_tensors, dim=self.dim_to_pack).to(device)
def __len__(self) -> int:
# this is the number of tensors in this data wrapper
return len(self.tensors)
def to(self, device: str | torch.device) -> "PackedTensor":
- self.tensors = [item.to(device) for item in self.tensors]
+ self.tensors = [
+ item.to(device) if item is not None else None for item in self.tensors
+ ]
return self
def slice(self, indices: Union[list[int], torch.Tensor]) -> "PackedTensor":
@@ -65,6 +78,11 @@ def slice(self, indices: Union[list[int], torch.Tensor]) -> "PackedTensor":
tensors = [self.tensors[i] for i in idx]
return PackedTensor(tensors, self.dim_to_pack)
+ @classmethod
+ def empty_like(cls, other: "PackedTensor") -> "PackedTensor":
+ """Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None."""
+ return cls([None] * len(other.tensors), other.dim_to_pack)
+
@classmethod
def concat(cls, from_packed_tensors: list["PackedTensor"]) -> "PackedTensor":
"""Concatenate a list of PackedTensor objects into a single PackedTensor.
diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py
index 4d4187a46b..5c0cd81003 100644
--- a/nemo_rl/distributed/batched_data_dict.py
+++ b/nemo_rl/distributed/batched_data_dict.py
@@ -126,7 +126,7 @@ def from_batches(
item for sublist in list_of_tensors for item in sublist
]
elif isinstance(list_of_tensors[0], PackedTensor):
- tensor_or_list = PackedTensor.flattened_concat(list_of_tensors)
+ tensor_or_list = PackedTensor.concat(list_of_tensors)
elif all(x.ndim == 1 for x in list_of_tensors):
tensor_or_list = torch.cat(list_of_tensors)
elif isinstance(list_of_tensors[0], torch.Tensor):
diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py
index 0243464e56..d4a8cd88ef 100644
--- a/nemo_rl/models/generation/vllm/utils.py
+++ b/nemo_rl/models/generation/vllm/utils.py
@@ -67,7 +67,10 @@ def _get_regular_prompt(index: int):
prompt_dict = {"prompt": msg}
# add additional data if present
images = data.get("vllm_images", None)
- if images is not None:
+ if images is None or len(images[i]) == 0:
+ prompts.append(_get_regular_prompt(i))
+ continue
+ else:
prompt_dict["multi_modal_data"] = {
"image": images[i][0] if len(images[i]) == 1 else images[i]
}
diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py
index 38078dca13..87a0ddb1d5 100644
--- a/nemo_rl/models/megatron/common.py
+++ b/nemo_rl/models/megatron/common.py
@@ -342,12 +342,19 @@ def forward_step_arbitrary_loss(
pad_mask_loss=False,
)
+ multimodal_data = data_dict.get_multimodal_dict(
+ as_tensors=True, device=input_ids_cp_sharded.device
+ )
+ if len(multimodal_data) > 0:
+ position_ids = None
+
with straggler_timer:
output_tensor = model(
- input_ids_cp_sharded,
- position_ids,
- attention_mask,
+ input_ids=input_ids_cp_sharded,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
+ **multimodal_data,
)
# Apply temperature scaling to logits for training
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 1f87676060..5292128451 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -252,6 +252,9 @@ def freeze_moe_router(megatron_model):
# Handle both wrapped (Float16Module) and unwrapped models
if isinstance(model_module, Float16Module):
model_module = model_module.module
+ # Handle VLM models
+ if hasattr(model_module, "language_model"):
+ model_module = model_module.language_model
for layer in model_module.decoder.layers:
if hasattr(layer.mlp, "router"):
layer.mlp.router.weight.requires_grad = False
@@ -265,6 +268,9 @@ def re_enable_float32_expert_bias(megatron_model):
# Handle both wrapped (Float16Module) and unwrapped models
if isinstance(model_module, Float16Module):
model_module = model_module.module
+ # Handle VLM models
+ if hasattr(model_module, "language_model"):
+ model_module = model_module.language_model
for layer in model_module.decoder.layers:
if hasattr(layer.mlp, "router"):
layer.mlp.router._maintain_float32_expert_bias()
@@ -1199,11 +1205,18 @@ def forward_step_fn(
packed_seq_params = None
unpacked_input_ids = input_ids
+ multimodal_data = data_dict.get_multimodal_dict(
+ as_tensors=True, device=input_ids.device
+ )
+ if len(multimodal_data) > 0:
+ position_ids = None
+
output_tensor = model(
- input_ids_cp_sharded,
- position_ids,
- attention_mask,
+ input_ids=input_ids_cp_sharded,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
+ **multimodal_data,
)
# Apply temperature scaling to logits for training
diff --git a/pyproject.toml b/pyproject.toml
index 69d3d9fea4..36e24a6365 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,9 +25,7 @@ dependencies = [
"triton",
"colored==2.2.3",
"ray[default]==2.46.0",
- # transformers==4.54.0/4.54.1 both fail on rm models
- # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
- "transformers>=4.51.0,<4.54.0",
+ "transformers>=4.55.4",
"wandb",
"numpy",
"datasets>=4.0.0",
diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt
index 76e4e55429..9e7f8ff3be 100644
--- a/tests/test_suites/nightly.txt
+++ b/tests/test_suites/nightly.txt
@@ -21,7 +21,10 @@ tests/test_suites/llm/grpo-moonlight-16ba3b-4n8g-megatron.sh
# Functional VLM run
tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.sh
-tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh
+tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh
+
+# Disabled until this upstream issue is resolved: https://github.com/huggingface/transformers/issues/41190
+# tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh
# Deepscaler (short tests)
tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh
diff --git a/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh b/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh
new file mode 100755
index 0000000000..b3c6764f65
--- /dev/null
+++ b/tests/test_suites/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-megatrontp2.v1.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=1
+STEPS_PER_RUN=200
+MAX_STEPS=200
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
+NUM_MINUTES=180
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_vlm_grpo.py \
+ --config $CONFIG_PATH \
+ grpo.max_num_steps=$MAX_STEPS \
+ logger.log_dir=$LOG_DIR \
+ logger.wandb_enabled=True \
+ logger.wandb.project=nemo-rl \
+ logger.wandb.name=$EXP_NAME \
+ logger.monitor_gpus=True \
+ logger.tensorboard_enabled=True \
+ checkpointing.enabled=True \
+ checkpointing.checkpoint_dir=$CKPT_DIR \
+ $@ \
+ 2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+ uv run tests/check_metrics.py $JSON_METRICS \
+ 'data["train/loss"]["200"] < 0.1' \
+ 'data["train/reward"]["200"] > 0.9'
+fi
+
diff --git a/tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh b/tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh.disabled
similarity index 100%
rename from tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh
rename to tests/test_suites/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.sh.disabled
diff --git a/tests/unit/data/test_multimodal_dict.py b/tests/unit/data/test_multimodal_dict.py
index ff95534e83..a94412222a 100644
--- a/tests/unit/data/test_multimodal_dict.py
+++ b/tests/unit/data/test_multimodal_dict.py
@@ -316,3 +316,36 @@ def test_get_multimodal_dict():
assert "token_type_ids" in mm_dict
assert isinstance(mm_dict["image_features"], PackedTensor)
assert torch.is_tensor(mm_dict["token_type_ids"])
+
+
+def test_packedtensor_all_none():
+ pt = PackedTensor([None, None], dim_to_pack=0)
+ assert pt.as_tensor() is None
+
+
+def test_packedtensor_with_none_entry():
+ original = PackedTensor([torch.randn(2, 3), None], dim_to_pack=0)
+ empty = PackedTensor.empty_like(original)
+ # same logical length
+ assert len(empty) == len(original)
+ # all entries are None, thus as_tensor returns None
+ assert empty.as_tensor() is None
+
+
+def test_packedtensor_to_with_none_entry():
+ t = torch.randn(1, 2)
+ pt = PackedTensor([None, t], dim_to_pack=0)
+ pt = pt.to("cpu")
+ assert pt.tensors[0] is None
+ assert isinstance(pt.tensors[1], torch.Tensor)
+ assert pt.tensors[1].device.type == "cpu"
+
+
+def test_packedtensor_as_tensor_with_mixed_none_and_tensors():
+ t1 = torch.randn(2, 3)
+ t2 = None
+ t3 = torch.randn(4, 3)
+ pt = PackedTensor([t1, t2, t3], dim_to_pack=0)
+ out = pt.as_tensor()
+ expected = torch.cat([t1, t3], dim=0)
+ assert torch.equal(out, expected)
diff --git a/tests/unit/models/generation/test_vllm_utils.py b/tests/unit/models/generation/test_vllm_utils.py
new file mode 100644
index 0000000000..4093b4c5ae
--- /dev/null
+++ b/tests/unit/models/generation/test_vllm_utils.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.models.generation.vllm.utils import (
+ format_prompt_for_vllm_generation,
+)
+
+
+def _mk_inputs(batch_size: int = 2, seq_len: int = 5):
+ input_ids = torch.arange(batch_size * seq_len).view(batch_size, seq_len)
+ # make second example shorter
+ input_lengths = torch.tensor([seq_len, seq_len - 2])
+ return input_ids, input_lengths
+
+
+def test_vllm_utils_regular_llm_path():
+ input_ids, input_lengths = _mk_inputs()
+ data = BatchedDataDict(
+ {
+ "input_ids": input_ids,
+ "input_lengths": input_lengths,
+ }
+ )
+ prompts = format_prompt_for_vllm_generation(data)
+ assert isinstance(prompts, list) and len(prompts) == 2
+ # first has full length
+ assert prompts[0]["prompt_token_ids"] == input_ids[0].tolist()
+ # second trimmed by input_lengths
+ assert prompts[1]["prompt_token_ids"] == input_ids[1, : input_lengths[1]].tolist()
+
+
+def test_vllm_utils_vlm_with_images_and_text():
+ # Batch with two samples
+ # both have content; first has one image, second has two images
+ input_ids, input_lengths = _mk_inputs()
+ data = BatchedDataDict(
+ {
+ "input_ids": input_ids,
+ "input_lengths": input_lengths,
+ "vllm_content": ["user: hi", "user: hello"],
+ "vllm_images": [["img1"], ["img2a", "img2b"]],
+ }
+ )
+
+ prompts = format_prompt_for_vllm_generation(data)
+ assert len(prompts) == 2
+ assert prompts[0]["prompt"] == "user: hi"
+ assert prompts[0]["multi_modal_data"]["image"] == "img1"
+ assert prompts[1]["prompt"] == "user: hello"
+ assert prompts[1]["multi_modal_data"]["image"] == ["img2a", "img2b"]
+
+
+def test_vllm_utils_vlm_with_missing_images_fallback_to_tokens():
+ input_ids, input_lengths = _mk_inputs()
+ # images None triggers fallback
+ data_none = BatchedDataDict(
+ {
+ "input_ids": input_ids,
+ "input_lengths": input_lengths,
+ "vllm_content": ["a", "b"],
+ "vllm_images": None,
+ }
+ )
+ prompts = format_prompt_for_vllm_generation(data_none)
+ assert all("prompt_token_ids" in p for p in prompts)
+
+ # images empty per sample also triggers fallback
+ data_empty = BatchedDataDict(
+ {
+ "input_ids": input_ids,
+ "input_lengths": input_lengths,
+ "vllm_content": ["a", "b"],
+ "vllm_images": [[], []],
+ }
+ )
+ prompts = format_prompt_for_vllm_generation(data_empty)
+ assert all("prompt_token_ids" in p for p in prompts)
+
+
+def test_vllm_utils_vlm_with_none_content_fallback_to_tokens_and_sample_idx():
+ input_ids, input_lengths = _mk_inputs()
+ data = BatchedDataDict(
+ {
+ "input_ids": input_ids,
+ "input_lengths": input_lengths,
+ "vllm_content": [None, None],
+ "vllm_images": [["img"], ["img"]],
+ }
+ )
+ # even though images are provided, None content should fall back to token IDs
+ prompts_all = format_prompt_for_vllm_generation(data)
+ assert len(prompts_all) == 2
+ assert all("prompt_token_ids" in p for p in prompts_all)
+
+ # single-sample API
+ p0 = format_prompt_for_vllm_generation(data, sample_idx=0)
+ p1 = format_prompt_for_vllm_generation(data, sample_idx=1)
+ assert isinstance(p0, dict) and isinstance(p1, dict)
+ assert "prompt_token_ids" in p0 and "prompt_token_ids" in p1
diff --git a/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py b/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py
new file mode 100644
index 0000000000..c1e4e927b9
--- /dev/null
+++ b/tests/unit/models/huggingface/test_smolvlm_embeddings_bug.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+from transformers import AutoModelForImageTextToText, AutoProcessor
+
+
+class SmolVLMVisionEmbeddingsReference(nn.Module):
+ """
+ Previous (correct) implementation in transformers<=4.54.1. Copied from https://github.com/huggingface/transformers/blob/4.54.1/src/transformers/models/smolvlm/modeling_smolvlm.py#L101-L156
+
+ Remove this test once the upstream bug is fixed.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+def test_smolvlm_embeddings_differ_from_reference():
+ # Remove once https://github.com/huggingface/transformers/issues/41190 is fixed and adopted.
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
+ processor = AutoProcessor.from_pretrained(model_path)
+ model = AutoModelForImageTextToText.from_pretrained(
+ model_path, torch_dtype=torch.bfloat16
+ )
+ model = model.to(device)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+ },
+ {"type": "text", "text": "Can you describe this image?"},
+ ],
+ }
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ inputs = {
+ k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()
+ }
+ inputs = {
+ k: v.to(dtype=torch.bfloat16)
+ if isinstance(v, torch.Tensor) and v.is_floating_point()
+ else v
+ for k, v in inputs.items()
+ }
+
+ patch_size = model.model.vision_model.patch_size
+ pixel_values = inputs["pixel_values"] # (bsz, num_images, 3, H, W)
+ bsz, num_images, _, H, W = pixel_values.shape
+ pixel_values = pixel_values.view(bsz * num_images, *pixel_values.shape[2:])
+
+ patch_attention_mask = torch.ones(
+ (
+ bsz,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ ),
+ device=pixel_values.device,
+ dtype=torch.bool,
+ )
+
+ # Get buggy/current embeddings module from installed transformers
+ embeddings_buggy = model.model.vision_model.embeddings
+
+ with torch.no_grad():
+ out_buggy = embeddings_buggy(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ # Build reference embeddings and copy weights for apples-to-apples comparison
+ ref = SmolVLMVisionEmbeddingsReference(model.model.vision_model.config)
+ ref = ref.to(device=device, dtype=torch.bfloat16)
+
+ # Copy the conv and embedding weights
+ ref.patch_embedding.load_state_dict(embeddings_buggy.patch_embedding.state_dict())
+ ref.position_embedding.load_state_dict(
+ embeddings_buggy.position_embedding.state_dict()
+ )
+
+ with torch.no_grad():
+ out_ref = ref(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ # Assert outputs differ due to the upstream bug
+ are_equal = torch.allclose(out_buggy.float(), out_ref.float(), atol=0, rtol=0)
+ assert not are_equal, (
+ "If this fails, that means the upstream bug has been fixed. You can close this issue: https://github.com/huggingface/transformers/issues/41190"
+ )
diff --git a/uv.lock b/uv.lock
index 50e2abc88f..f93fc1053e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2335,7 +2335,7 @@ requires-dist = [
{ name = "torch", marker = "sys_platform == 'darwin'", index = "https://pypi.org/simple" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "transformer-engine", extras = ["pytorch"], marker = "sys_platform != 'darwin'", specifier = ">=2.5.0a0,<2.6.0" },
- { name = "transformers", specifier = ">=4.51.3" },
+ { name = "transformers", specifier = ">=4.55.0" },
{ name = "typing-extensions" },
{ name = "wandb", specifier = ">=0.19.10" },
]
@@ -3041,7 +3041,7 @@ requires-dist = [
{ name = "torchvision", marker = "sys_platform != 'darwin'", specifier = ">=0.22.0", index = "https://download.pytorch.org/whl/cu128" },
{ name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.22.0", index = "https://pypi.org/simple" },
{ name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'mcore'", specifier = "==2.5.0" },
- { name = "transformers", specifier = ">=4.51.0,<4.54.0" },
+ { name = "transformers", specifier = ">=4.55.4" },
{ name = "triton", marker = "sys_platform != 'darwin'", index = "https://download.pytorch.org/whl/cu128" },
{ name = "triton", marker = "sys_platform == 'darwin'", index = "https://pypi.org/simple" },
{ name = "vllm", marker = "extra == 'automodel'", specifier = "==0.10.0" },
@@ -6052,7 +6052,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a1/1d/73ec467d20d96e0bb
[[package]]
name = "transformers"
-version = "4.53.3"
+version = "4.55.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
@@ -6066,9 +6066,9 @@ dependencies = [
{ name = "tokenizers" },
{ name = "tqdm" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/f1/5c/49182918b58eaa0b4c954fd0e37c79fc299e5643e69d70089d0b0eb0cd9b/transformers-4.53.3.tar.gz", hash = "sha256:b2eda1a261de79b78b97f7888fe2005fc0c3fabf5dad33d52cc02983f9f675d8", size = 9197478, upload-time = "2025-07-22T07:30:51.51Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/2b/43/3cb831d5f28cc723516e5bb43a8c6042aca3038bb36b6bd6016b40dfd1e8/transformers-4.55.4.tar.gz", hash = "sha256:574a30559bc273c7a4585599ff28ab6b676e96dc56ffd2025ecfce2fd0ab915d", size = 9573015, upload-time = "2025-08-22T15:18:43.192Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382, upload-time = "2025-07-22T07:30:48.458Z" },
+ { url = "https://files.pythonhosted.org/packages/fa/0a/8791a6ee0529c45f669566969e99b75e2ab20eb0bfee8794ce295c18bdad/transformers-4.55.4-py3-none-any.whl", hash = "sha256:df28f3849665faba4af5106f0db4510323277c4bb595055340544f7e59d06458", size = 11269659, upload-time = "2025-08-22T15:18:40.025Z" },
]
[[package]]