2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge-workspace/setup.py
@@ -33,7 +33,7 @@
"packaging",
"tensorboard>=2.19.0",
"torch",
"transformers>=4.51.3",
"transformers>=4.55.0",
"typing-extensions",
"rich",
"wandb>=0.19.10",
183 changes: 183 additions & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -0,0 +1,183 @@
# GRPO Algorithm Configuration
defaults: "vlm_grpo_3B.yaml"

policy:
  model_name: "Qwen/Qwen2.5-VL-3B-Instruct"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 128
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 4
  max_total_sequence_length: 2048
  precision: "bfloat16"

  dtensor_cfg:
    enabled: false

  # See docs/design-docs/sequence-packing-and-dynamic-batching.md
  # for more details on dynamic batching and sequence packing.
  #
  # We disable dynamic batching for Megatron as it is incompatible with pipeline parallelism.
  # Instead, we use sequence packing.
  dynamic_batching:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  max_grad_norm: 1.0
  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  megatron_cfg:
    enabled: true
    empty_unused_memory_level: 0
    activation_checkpointing: false
    converter_type: "Qwen2ForCausalLM"
    tensor_model_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    pipeline_model_parallel_size: 1
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    sequence_parallel: false
    freeze_moe_router: true
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    moe_permute_fusion: false
    # gives ~20% training perf speedup with sequence packing
    apply_rope_fusion: True
    optimizer:
      optimizer: "adam"
      lr: 2.0e-7
      min_lr: 2.0e-7
      weight_decay: 0.01
      bf16: true
      fp16: false
      params_dtype: "float32"

      # adam
      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      # sgd
      sgd_momentum: 0.9

      # distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: 1000
      lr_warmup_iters: 50
      lr_warmup_init: 2.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: false
      overlap_param_gather: true
      average_in_collective: true
      use_custom_fsdp: false
      data_parallel_sharding_strategy: "optim_grads_params"
  generation:
    backend: "vllm"
    # max_new_tokens: ${policy.max_total_sequence_length}
    max_new_tokens: 1024
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false # Only for internal testing, will be enabled by https://github.com/NVIDIA/NeMo-RL/issues/447.
      precision: ${policy.precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      enable_expert_parallel: false
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      enforce_eager: False
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources
      enabled: true
      # only relevant when enabled is false
      resources:
        gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1
        num_nodes: null # Decides number of nodes to be dedicated to generation

data:
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/clevr_cogent_cot.txt"
  system_prompt_file: null
  dataset_name: "clevr-cogent"
  split: "trainA"
  shuffle: true

env:
  clevr-cogent:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.2
      - name: exact_alnum
        weight: 0.8
  geometry3k:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.1
      - name: math_expr
        weight: 0.9
  refcoco:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.1
      - name: bbox_giou
        weight: 0.9
        kwargs:
          giou_penalty_thres: 0.5

logger:
  log_dir: "logs" # Base directory for all logs
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: false
  tensorboard_enabled: true
  mlflow_enabled: false # Disable MLflow logging
  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "grpo-dev"
    name: "vlm-grpo-3b-megatron"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)


cluster:
  gpus_per_node: 2
  num_nodes: 1
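The recipe above targets a 2-GPU node. As a hedged illustration (not part of this PR), a derived config could scale it out by overriding only the parallelism and cluster keys, following the same defaults-inheritance pattern this file uses; the values below are illustrative assumptions, not tuned settings.

# Hypothetical child config layered on top of the PR's config
defaults: "vlm_grpo_3B_megatron.yaml"

policy:
  megatron_cfg:
    tensor_model_parallel_size: 2
    sequence_parallel: true
  generation:
    vllm_cfg:
      tensor_parallel_size: 2

cluster:
  gpus_per_node: 8
  num_nodes: 1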
26 changes: 18 additions & 8 deletions examples/run_vlm_grpo.py
@@ -200,16 +200,29 @@ def hf_data_processor(

    length = sum(len(m["token_ids"]) for m in message_log)
    loss_multiplier = 1.0
    if length > max_seq_length:
    if length >= max_seq_length:
        # Treat truncated messages as text only
        vllm_kwargs = {
            "vllm_content": None,
            "vllm_images": [],
        }

        # make smaller and mask out
        for chat_message in message_log:
            chat_message["token_ids"] = chat_message["token_ids"][
                : min(4, max_seq_length // len(message_log))
            ]
            for key, value in chat_message.items():
                if isinstance(value, PackedTensor):
                    chat_message[key] = PackedTensor.empty_like(value)
        loss_multiplier = 0.0
        raise NotImplementedError(
            "Sequence length is too long, please use a shorter sequence length"
        )
    else:
        # get the prompt content for the entire conversation (used by the vLLM backend,
        # which needs the formatted dialog and the list of images)
        # add images for vllm serving
        vllm_kwargs = {
            "vllm_content": string_formatted_dialog,
            "vllm_images": images,
        }

    output: DatumSpec = {
        "message_log": message_log,
@@ -218,10 +231,7 @@ def hf_data_processor(
"loss_multiplier": loss_multiplier,
"idx": idx,
"task_name": task_data_spec.task_name,
# get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
# add images for vllm serving
"vllm_content": string_formatted_dialog,
"vllm_images": images,
**vllm_kwargs,
}
return output
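With this change the over-long branch no longer raises; the sample is kept but neutralized via loss_multiplier = 0.0 and emptied PackedTensor fields. A simplified sketch of that masking idea (an assumption about how the loss reduction uses the multiplier, not code from this repo):

import torch

# per-sample losses for a batch of 3 rollouts
per_sample_loss = torch.tensor([0.7, 1.2, 0.9])
# the second sample exceeded max_seq_length, so its multiplier is 0.0
loss_multiplier = torch.tensor([1.0, 0.0, 1.0])
# masked samples contribute nothing to the reduced loss
loss = (per_sample_loss * loss_multiplier).sum() / loss_multiplier.sum().clamp(min=1.0)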

30 changes: 23 additions & 7 deletions nemo_rl/data/multimodal_utils.py
@@ -30,41 +30,57 @@ class PackedTensor:
"""

def __init__(
self, tensors: Union[torch.Tensor, list[torch.Tensor]], dim_to_pack: int
self,
tensors: Union[torch.Tensor, list[Optional[torch.Tensor]], list[None]],
dim_to_pack: int,
) -> None:
assert tensors is not None, "Input tensors to PackedTensor cannot be None"

if isinstance(tensors, torch.Tensor):
self.tensors: list[torch.Tensor] = [tensors]
self.tensors: list[Optional[torch.Tensor]] = [tensors]
elif isinstance(tensors, list):
assert len(tensors) > 0, (
"Input tensors to PackedTensor must be a non-empty list"
)
self.tensors: list[torch.Tensor] = tensors
self.tensors: list[Optional[torch.Tensor]] = tensors
else:
raise ValueError(
f"Unsupported type for input tensors to PackedTensor: {type(tensors)}"
)
self.dim_to_pack = dim_to_pack

def as_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
def as_tensor(self, device: Optional[torch.device] = None) -> Optional[torch.Tensor]:
if device is not None:
self.tensors = [item.to(device) for item in self.tensors]
return torch.cat(self.tensors, dim=self.dim_to_pack).to(device)
# Move only non-None tensors to device, preserve Nones
for i, item in enumerate(self.tensors):
if item is not None:
self.tensors[i] = item.to(device)
non_none_tensors = [t for t in self.tensors if t is not None]
if len(non_none_tensors) == 0:
return None
else:
return torch.cat(non_none_tensors, dim=self.dim_to_pack).to(device)

def __len__(self) -> int:
# this is the number of tensors in this data wrapper
return len(self.tensors)

def to(self, device: str | torch.device) -> "PackedTensor":
self.tensors = [item.to(device) for item in self.tensors]
self.tensors = [
item.to(device) if item is not None else None for item in self.tensors
]
return self

def slice(self, indices: Union[list[int], torch.Tensor]) -> "PackedTensor":
idx = indices.tolist() if isinstance(indices, torch.Tensor) else indices
tensors = [self.tensors[i] for i in idx]
return PackedTensor(tensors, self.dim_to_pack)

@classmethod
def empty_like(cls, other: "PackedTensor") -> "PackedTensor":
"""Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None."""
return cls([None] * len(other.tensors), other.dim_to_pack)

@classmethod
def concat(cls, from_packed_tensors: list["PackedTensor"]) -> "PackedTensor":
"""Concatenate a list of PackedTensor objects into a single PackedTensor.
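For orientation, a minimal usage sketch (not part of the PR) of the None-aware behavior added above, assuming PackedTensor is imported from nemo_rl/data/multimodal_utils.py:

import torch

from nemo_rl.data.multimodal_utils import PackedTensor

# two image-feature tensors packed along dim 0
packed = PackedTensor([torch.randn(2, 4), torch.randn(3, 4)], dim_to_pack=0)
assert packed.as_tensor().shape == (5, 4)

# a truncated sample: same length and dim_to_pack, but every entry is None
empty = PackedTensor.empty_like(packed)
assert len(empty) == 2
assert empty.as_tensor() is None  # nothing left to concatenate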
2 changes: 1 addition & 1 deletion nemo_rl/distributed/batched_data_dict.py
@@ -126,7 +126,7 @@ def from_batches(
                    item for sublist in list_of_tensors for item in sublist
                ]
            elif isinstance(list_of_tensors[0], PackedTensor):
                tensor_or_list = PackedTensor.flattened_concat(list_of_tensors)
                tensor_or_list = PackedTensor.concat(list_of_tensors)
            elif all(x.ndim == 1 for x in list_of_tensors):
                tensor_or_list = torch.cat(list_of_tensors)
            elif isinstance(list_of_tensors[0], torch.Tensor):
5 changes: 4 additions & 1 deletion nemo_rl/models/generation/vllm/utils.py
@@ -67,7 +67,10 @@ def _get_regular_prompt(index: int):
prompt_dict = {"prompt": msg}
# add additional data if present
images = data.get("vllm_images", None)
if images is not None:
if images is None or len(images[i]) == 0:
prompts.append(_get_regular_prompt(i))
continue
else:
prompt_dict["multi_modal_data"] = {
"image": images[i][0] if len(images[i]) == 1 else images[i]
}
8 changes: 5 additions & 3 deletions nemo_rl/models/megatron/common.py
@@ -342,12 +342,14 @@ def forward_step_arbitrary_loss(
            pad_mask_loss=False,
        )

    multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids_cp_sharded.device)
    with straggler_timer:
        output_tensor = model(
            input_ids_cp_sharded,
            position_ids,
            attention_mask,
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
            packed_seq_params=packed_seq_params,
            **multimodal_data,
        )

    # Apply temperature scaling to logits for training
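The get_multimodal_dict call is the only VLM-specific addition to this forward step: whatever multimodal tensors the batch carries are forwarded as keyword arguments, so a text-only batch adds nothing. A hypothetical sketch of how such a helper could be built on top of PackedTensor (the function name and behavior are assumptions, not the repo's implementation):

from typing import Any, Optional

import torch

from nemo_rl.data.multimodal_utils import PackedTensor


def collect_multimodal_kwargs(
    batch: dict[str, Any], device: Optional[torch.device] = None
) -> dict[str, torch.Tensor]:
    """Gather PackedTensor fields into model kwargs, dropping fields emptied by truncation."""
    kwargs: dict[str, torch.Tensor] = {}
    for key, value in batch.items():
        if isinstance(value, PackedTensor):
            tensor = value.as_tensor(device=device)
            if tensor is not None:  # all-None PackedTensors are skipped entirely
                kwargs[key] = tensor
    return kwargs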
14 changes: 11 additions & 3 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -251,6 +251,9 @@ def freeze_moe_router(megatron_model):
        # Handle both wrapped (Float16Module) and unwrapped models
        if isinstance(model_module, Float16Module):
            model_module = model_module.module
        # Handle VLM models
        if hasattr(model_module, "language_model"):
            model_module = model_module.language_model
        for layer in model_module.decoder.layers:
            if hasattr(layer.mlp, "router"):
                layer.mlp.router.weight.requires_grad = False
@@ -264,6 +267,9 @@ def re_enable_float32_expert_bias(megatron_model):
        # Handle both wrapped (Float16Module) and unwrapped models
        if isinstance(model_module, Float16Module):
            model_module = model_module.module
        # Handle VLM models
        if hasattr(model_module, "language_model"):
            model_module = model_module.language_model
        for layer in model_module.decoder.layers:
            if hasattr(layer.mlp, "router"):
                layer.mlp.router._maintain_float32_expert_bias()
@@ -1178,11 +1184,13 @@ def forward_step_fn(
            packed_seq_params = None
            unpacked_input_ids = input_ids

        multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids.device)
        output_tensor = model(
            input_ids_cp_sharded,
            position_ids,
            attention_mask,
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
            packed_seq_params=packed_seq_params,
            **multimodal_data,
        )

        # Apply temperature scaling to logits for training
4 changes: 1 addition & 3 deletions pyproject.toml
@@ -25,9 +25,7 @@ dependencies = [
"triton",
"colored==2.2.3",
"ray[default]==2.46.0",
# transformers==4.54.0/4.54.1 both fail on rm models
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
"transformers>=4.51.0,<4.54.0",
"transformers==4.55.4",
"wandb",
"numpy",
"datasets>=4.0.0",