Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"

data:
Expand Down
1 change: 0 additions & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: "optim_grads_params"

Expand Down
1 change: 0 additions & 1 deletion examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: "optim_grads_params"

Expand Down
53 changes: 53 additions & 0 deletions examples/configs/recipes/llm/sft-qwen2.5-math-7b-megatron.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SFT recipe for Qwen/Qwen2.5-Math-7B with the Megatron backend enabled and
# the dtensor backend disabled.
# NOTE(review): key nesting (indentation) is not visible in this view; confirm
# the hierarchy against the committed file before relying on these comments.
# Presumably the keys below are layered on top of the base config named in
# `defaults` -- confirm against the config loader.
defaults: ../../sft.yaml
sft:
max_num_steps: 80
checkpointing:
enabled: false
policy:
model_name: Qwen/Qwen2.5-Math-7B
train_global_batch_size: 512
max_total_sequence_length: 16384
dtensor_cfg:
enabled: false
megatron_cfg:
enabled: true
# Model-parallel layout: tensor parallel 4, context parallel 2.
tensor_model_parallel_size: 4
context_parallel_size: 2
sequence_parallel: true
# NOTE(review): the MoE router options below look like they were carried over
# from other recipes; confirm they are needed for this model.
freeze_moe_router: true
moe_router_dtype: fp64
moe_router_bias_update_rate: 0.0
moe_permute_fusion: true
optimizer:
# lr and min_lr are equal (1.0e-06).
lr: 1.0e-06
min_lr: 1.0e-06
bf16: true
adam_beta2: 0.999
adam_eps: 1.0e-08
use_distributed_optimizer: false
use_precision_aware_optimizer: false
scheduler:
# 10 warmup iterations starting from 1.0e-11; no decay iteration count is set.
lr_decay_iters: null
lr_warmup_iters: 10
lr_warmup_init: 1.0e-11
sequence_packing:
enabled: true
make_sequence_length_divisible_by: 32
data:
dataset_name: openmathinstruct2
prompt_file: examples/prompts/math.txt
split: train_1M
add_generation_prompt: true
output_key: generated_solution
num_workers: 8
# Logger run names all share the recipe name.
logger:
wandb:
project: nemo-rl
name: sft-qwen2.5-math-7b-megatron
tensorboard:
log_dir: tb_logs-sft-qwen2.5-math-7b-megatron
mlflow:
run_name: sft-qwen2.5-math-7b-megatron
# 2 nodes x 8 GPUs per node = 16 GPUs.
cluster:
gpus_per_node: 8
num_nodes: 2
1 change: 0 additions & 1 deletion examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: false
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"


Expand Down
1 change: 0 additions & 1 deletion examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"
use_custom_fsdp: false

Expand Down
1 change: 0 additions & 1 deletion examples/configs/sft_openmathinstruct2_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ policy:
activation_checkpointing: false
context_parallel_size: 1
distributed_data_parallel_config:
average_in_collective: true
data_parallel_sharding_strategy: optim_grads_params
grad_reduce_in_fp32: true
overlap_grad_reduce: true
Expand Down
1 change: 0 additions & 1 deletion examples/configs/vlm_grpo_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: "optim_grads_params"

Expand Down
1 change: 0 additions & 1 deletion examples/configs/vlm_grpo_3B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ policy:
grad_reduce_in_fp32: false
overlap_grad_reduce: false
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: optim_grads_params
data:
Expand Down
65 changes: 62 additions & 3 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,21 @@ def __init__(
"https://github.com/NVIDIA-NeMo/RL/blob/bccbc377705a81a1f4b3c31ad9767bcc15f735a8/nemo_rl/algorithms/sft.py#L175-L179."
)

## These settings are required for correct gradient computation in mcore.
## When calculate_per_token_loss is True, mcore does not scale the gradients,
## so we handle the scaling in nemo-rl.
## perform_initialization = True is a workaround to ensure the correct tensor parallel attributes are set
## on the TP-sharded parameters.
model_cfg.calculate_per_token_loss = True
model_cfg.perform_initialization = True

assert (
"aux_loss" not in model_cfg.moe_router_load_balancing_type
or model_cfg.moe_aux_loss_coeff == 0
), (
"MoE aux loss is currently not supported due to a known but in Megatron-LM. See ## TODO: link to GH issue"
)

Comment thread
ashors1 marked this conversation as resolved.
self.megatron_cfg = ConfigContainer(
model=model_cfg,
checkpoint=checkpoint_config,
Expand All @@ -689,9 +704,9 @@ def __init__(
overlap_param_gather=self.cfg["megatron_cfg"][
"distributed_data_parallel_config"
]["overlap_param_gather"],
average_in_collective=self.cfg["megatron_cfg"][
"distributed_data_parallel_config"
]["average_in_collective"],
# we need to set average_in_collective=False with calculate_per_token_loss=True.
# otherwise, mcore throws an assertion error.
average_in_collective=False,
use_distributed_optimizer=self.cfg["megatron_cfg"]["optimizer"][
"use_distributed_optimizer"
],
Expand Down Expand Up @@ -2231,3 +2246,47 @@ def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
ip = ray._private.services.get_node_ip_address()
gpu_id = ray.get_gpu_ids()[0]
return (ip, gpu_id)

def check_tensor_parallel_attributes(self) -> dict[str, Any]:
    """Report which model parameters are flagged as tensor-parallel.

    Walks ``self.model.named_parameters()`` and sorts every parameter into
    one of two buckets based on its ``tensor_model_parallel`` attribute
    (a missing attribute is treated as ``False``).

    Returns:
        A dictionary with four entries:
            - tp_params: one record (dict) per parameter whose
              ``tensor_model_parallel`` attribute is truthy, holding its
              name, the flag value, ``partition_dim``, ``partition_stride``
              (``None`` when absent) and shape as a list.
            - non_tp_params: one record (dict) per remaining parameter,
              holding its name, the flag value and shape as a list.
            - total_params: number of parameters visited.
            - tp_size: ``tensor_model_parallel_size`` read from
              ``self.megatron_cfg.model``.
    """
    sharded = []
    unsharded = []

    for param_name, tensor in self.model.named_parameters():
        tp_flag = getattr(tensor, "tensor_model_parallel", False)
        # Both record kinds begin with the same two keys.
        record = {"name": param_name, "tensor_model_parallel": tp_flag}

        if not tp_flag:
            record["shape"] = list(tensor.shape)
            unsharded.append(record)
            continue

        # Partition metadata is only collected for flagged parameters.
        record["partition_dim"] = getattr(tensor, "partition_dim", None)
        record["partition_stride"] = getattr(tensor, "partition_stride", None)
        record["shape"] = list(tensor.shape)
        sharded.append(record)

    # Each visited parameter lands in exactly one bucket, so the bucket
    # sizes add up to the number of parameters seen.
    return {
        "tp_params": sharded,
        "non_tp_params": unsharded,
        "total_params": len(sharded) + len(unsharded),
        "tp_size": self.megatron_cfg.model.tensor_model_parallel_size,
    }
43 changes: 43 additions & 0 deletions tests/test_suites/llm/sft-qwen2.5-math-7b-megatron.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/bin/bash
# Nightly test driver: runs examples/run_sft.py with the config in
# $CONFIG_PATH, dumps the tensorboard logs to JSON, and checks the final
# train/validation loss once the target step has been logged.
#
# Sourcing common.env is expected to provide PROJECT_ROOT, CONFIG_PATH,
# LOG_DIR, EXP_NAME, CKPT_DIR, RUN_LOG, JSON_METRICS and the
# exit_if_max_steps_reached helper (all are used below but not defined here).
# Extra command-line arguments are forwarded verbatim to run_sft.py.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source "$SCRIPT_DIR/common.env"

# TODO: this config can crash on OOM
# https://github.com/NVIDIA-NeMo/RL/issues/263

# ===== BEGIN CONFIG =====
NUM_NODES=2
STEPS_PER_RUN=80 # step_time ~ 29sec
MAX_STEPS=80
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=30
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
# Abort if the project root is unreachable instead of launching from the
# wrong directory.
cd "$PROJECT_ROOT" || exit 1
# "$@" (quoted) keeps each forwarded override as a single argument even when
# it contains spaces or glob characters.
uv run examples/run_sft.py \
    --config "$CONFIG_PATH" \
    sft.max_num_steps="$MAX_STEPS" \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name="$EXP_NAME" \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    ~policy.tokenizer.chat_template \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Only run metrics if the target step is reached
# The step key is derived from MAX_STEPS so it always matches the gate above;
# the numeric bounds are the original step-80 thresholds.
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py "$JSON_METRICS" \
        "data[\"train/loss\"][\"${MAX_STEPS}\"] < 0.301" \
        "data[\"validation/val_loss\"][\"${MAX_STEPS}\"] < 0.304"
fi
2 changes: 2 additions & 0 deletions tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ tests/test_suites/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.sh
tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron.sh
# sequence packing
tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh
# validate TP/DP
tests/test_suites/llm/sft-qwen2.5-math-7b-megatron.sh

#######
# DPO #
Expand Down
Loading
Loading