Merged
Changes from 7 commits
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge-workspace/setup.py
@@ -33,7 +33,7 @@
    "packaging",
    "tensorboard>=2.19.0",
    "torch",
-   "transformers>=4.51.3",
+   "transformers>=4.55.0",
    "typing-extensions",
    "rich",
    "wandb>=0.19.10",
183 changes: 183 additions & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -0,0 +1,183 @@
# GRPO Algorithm Configuration
defaults: "vlm_grpo_3B.yaml"

policy:
  model_name: "Qwen/Qwen2.5-VL-3B-Instruct"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 128
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 4
  max_total_sequence_length: 2048
  precision: "bfloat16"

  dtensor_cfg:
    enabled: false

  # See docs/design-docs/sequence-packing-and-dynamic-batching.md
  # for more details on dynamic batching and sequence packing.
  #
  # We disable dynamic batching for Megatron as it is incompatible with Pipeline parallelism.
  # Instead, we use sequence packing.
  dynamic_batching:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  max_grad_norm: 1.0
  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  megatron_cfg:
    enabled: true
    empty_unused_memory_level: 0
    activation_checkpointing: false
    converter_type: "Qwen2ForCausalLM"
    tensor_model_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    pipeline_model_parallel_size: 1
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    sequence_parallel: false
    freeze_moe_router: true
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    moe_permute_fusion: false
    # gives ~20% training perf speedup with sequence packing
    apply_rope_fusion: True

    optimizer:
      optimizer: "adam"
      lr: 2.0e-7
      min_lr: 2.0e-7
      weight_decay: 0.01
      bf16: true
      fp16: false
      params_dtype: "float32"

      # adam
      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      # sgd
      sgd_momentum: 0.9

      # distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: 1000
      lr_warmup_iters: 50
      lr_warmup_init: 2.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: false
      overlap_param_gather: true
      average_in_collective: true
      use_custom_fsdp: false
      data_parallel_sharding_strategy: "optim_grads_params"

  generation:
    backend: "vllm"
    # max_new_tokens: ${policy.max_total_sequence_length}
    max_new_tokens: 1024
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false # Only for internal testing, will be enabled by https://github.com/NVIDIA/NeMo-RL/issues/447.
      precision: ${policy.precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      enable_expert_parallel: false
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      enforce_eager: False
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources
      enabled: true
      # only relevant when enabled is false
      resources:
        gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster, i.e. cluster.num_nodes == 1
        num_nodes: null # Decides number of nodes to be dedicated to generation

data:
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/clevr_cogent_cot.txt"
  system_prompt_file: null
  dataset_name: "clevr-cogent"
  split: "trainA"
  shuffle: true

env:
  clevr-cogent:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.2
      - name: exact_alnum
        weight: 0.8
  geometry3k:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.1
      - name: math_expr
        weight: 0.9
  refcoco:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.1
      - name: bbox_giou
        weight: 0.9
        kwargs:
          giou_penalty_thres: 0.5

logger:
  log_dir: "logs" # Base directory for all logs
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: false
  tensorboard_enabled: true
  mlflow_enabled: false # Disable MLflow logging
  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "grpo-dev"
    name: "vlm-grpo-3b-megatron"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)


cluster:
  gpus_per_node: 2
  num_nodes: 1
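
Note: the train_mb_tokens and logprob_mb_tokens entries above rely on a "mul" resolver, which is not built into OmegaConf; NeMo-RL registers it so token budgets can be derived from other config fields. A minimal sketch of loading and resolving this config on its own follows. The mul lambda and the plain OmegaConf.merge handling of `defaults` are assumptions for illustration; the repo's own config loader may behave differently.

# Sketch: resolve this config standalone, assuming a simple "mul" resolver.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)

base = OmegaConf.load("examples/configs/vlm_grpo_3B.yaml")            # file named in `defaults` (path assumed)
override = OmegaConf.load("examples/configs/vlm_grpo_3B_megatron.yaml")
cfg = OmegaConf.merge(base, override)                                  # override wins on conflicts

# 2048 tokens per sequence * micro-batch size of 1 = 2048-token packing budget
print(cfg.policy.sequence_packing.train_mb_tokens)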
2 changes: 1 addition & 1 deletion nemo_rl/distributed/batched_data_dict.py
@@ -126,7 +126,7 @@ def from_batches(
                    item for sublist in list_of_tensors for item in sublist
                ]
            elif isinstance(list_of_tensors[0], PackedTensor):
-               tensor_or_list = PackedTensor.flattened_concat(list_of_tensors)
+               tensor_or_list = PackedTensor.concat(list_of_tensors)
            elif all(x.ndim == 1 for x in list_of_tensors):
                tensor_or_list = torch.cat(list_of_tensors)
            elif isinstance(list_of_tensors[0], torch.Tensor):
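
For context, the PackedTensor.flattened_concat → PackedTensor.concat swap sits inside the per-key collation of BatchedDataDict.from_batches, which merges the values each micro-batch contributed for a given key. A rough standalone sketch of that dispatch (with the PackedTensor branch omitted) follows; it is illustrative only, and the padding behaviour for higher-rank tensors is an assumption rather than a copy of the real implementation.

# Sketch of the per-key collation dispatch in from_batches (illustrative only;
# the PackedTensor branch from the diff is not reproduced here).
from typing import Any
import torch

def collate_values(values: list[Any]) -> Any:
    if isinstance(values[0], list):
        # Python lists from each batch are flattened into one list.
        return [item for sublist in values for item in sublist]
    elif all(isinstance(v, torch.Tensor) and v.ndim == 1 for v in values):
        # 1-D tensors (per-sample scalars such as rewards) concatenate directly.
        return torch.cat(values)
    elif isinstance(values[0], torch.Tensor):
        # Higher-rank tensors would first be padded to a common sequence length
        # (assumed here; handled elsewhere in the actual class).
        return torch.cat(values, dim=0)
    raise TypeError(f"Unsupported batch value type: {type(values[0])}")

print(collate_values([torch.ones(2), torch.zeros(3)]).shape)  # torch.Size([5])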
8 changes: 5 additions & 3 deletions nemo_rl/models/megatron/common.py
@@ -342,12 +342,14 @@ def forward_step_arbitrary_loss(
            pad_mask_loss=False,
        )

+   multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids_cp_sharded.device)
    with straggler_timer:
        output_tensor = model(
-           input_ids_cp_sharded,
-           position_ids,
-           attention_mask,
+           input_ids=input_ids_cp_sharded,
+           position_ids=position_ids,
+           attention_mask=attention_mask,
            packed_seq_params=packed_seq_params,
+           **multimodal_data,
        )

    # Apply temperature scaling to logits for training
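
The new get_multimodal_dict(...) call gathers whatever vision inputs the batch carries and forwards them to the Megatron model as keyword arguments next to the usual text inputs, so text-only batches keep working unchanged. A hedged sketch of that pattern follows; the key names (pixel_values and image_grid_thw are typical for Qwen2.5-VL) and the helper's signature are assumptions, not copied from BatchedDataDict.

# Sketch: collecting optional multimodal tensors and forwarding them as kwargs.
# Key names and the helper itself are assumptions for illustration.
import torch

VISION_KEYS = ("pixel_values", "image_grid_thw")

def gather_multimodal_kwargs(batch: dict, device: torch.device) -> dict[str, torch.Tensor]:
    out = {}
    for key in VISION_KEYS:
        if batch.get(key) is not None:
            out[key] = batch[key].to(device)
    return out

# For a text-only batch this is {}, so the call reduces to the original
# language-model signature:
#   output = model(input_ids=ids, position_ids=pos, attention_mask=mask,
#                  packed_seq_params=psp, **gather_multimodal_kwargs(batch, ids.device))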
14 changes: 11 additions & 3 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -251,6 +251,9 @@ def freeze_moe_router(megatron_model):
        # Handle both wrapped (Float16Module) and unwrapped models
        if isinstance(model_module, Float16Module):
            model_module = model_module.module
+       # Handle VLM models
+       if hasattr(model_module, "language_model"):
+           model_module = model_module.language_model
        for layer in model_module.decoder.layers:
            if hasattr(layer.mlp, "router"):
                layer.mlp.router.weight.requires_grad = False
@@ -264,6 +267,9 @@ def re_enable_float32_expert_bias(megatron_model):
        # Handle both wrapped (Float16Module) and unwrapped models
        if isinstance(model_module, Float16Module):
            model_module = model_module.module
+       # Handle VLM models
+       if hasattr(model_module, "language_model"):
+           model_module = model_module.language_model
        for layer in model_module.decoder.layers:
            if hasattr(layer.mlp, "router"):
                layer.mlp.router._maintain_float32_expert_bias()
@@ -1178,11 +1184,13 @@ def forward_step_fn(
            packed_seq_params = None
            unpacked_input_ids = input_ids

+       multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids.device)
        output_tensor = model(
-           input_ids_cp_sharded,
-           position_ids,
-           attention_mask,
+           input_ids=input_ids_cp_sharded,
+           position_ids=position_ids,
+           attention_mask=attention_mask,
            packed_seq_params=packed_seq_params,
+           **multimodal_data,
        )

        # Apply temperature scaling to logits for training
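
Both router utilities above now unwrap the model the same way: strip the Float16Module wrapper if present, then, for vision-language models, descend into the language_model submodule that owns decoder.layers. A small illustrative refactor of that shared pattern is sketched below; it is not part of the PR, and the Float16Module import path is assumed from Megatron-Core.

# Sketch: the unwrap pattern shared by freeze_moe_router and
# re_enable_float32_expert_bias, pulled into a helper (illustrative only).
from megatron.core.transformer.module import Float16Module  # import path assumed

def unwrap_language_model(model_module):
    """Return the submodule that actually exposes `.decoder.layers`."""
    if isinstance(model_module, Float16Module):
        model_module = model_module.module          # strip mixed-precision wrapper
    if hasattr(model_module, "language_model"):
        model_module = model_module.language_model  # VLMs nest the decoder here
    return model_module

def freeze_moe_router(megatron_model):
    # megatron_model is assumed to be an iterable of model chunks.
    for model_module in megatron_model:
        for layer in unwrap_language_model(model_module).decoder.layers:
            if hasattr(layer.mlp, "router"):
                layer.mlp.router.weight.requires_grad = False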
4 changes: 1 addition & 3 deletions pyproject.toml
@@ -25,9 +25,7 @@ dependencies = [
    "triton",
    "colored==2.2.3",
    "ray[default]==2.46.0",
-   # transformers==4.54.0/4.54.1 both fail on rm models
-   # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
-   "transformers>=4.51.0,<4.54.0",
+   "transformers==4.55.4",
    "wandb",
    "numpy",
    "datasets>=4.0.0",
10 changes: 5 additions & 5 deletions uv.lock
