Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d9c1dde
init commit
zpqiu Oct 9, 2025
a4c459a
fix CP all gather bug
zpqiu Oct 10, 2025
2fcddc1
fix PP bug
zpqiu Oct 11, 2025
9c9d4c4
add unit tests for get topk logits
zpqiu Oct 11, 2025
8e3fb12
add functionality test
zpqiu Oct 11, 2025
f48da95
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 11, 2025
f98dfe2
increase nightly compute time limitation
zpqiu Oct 11, 2025
5a36a60
add missing keys
zpqiu Oct 11, 2025
1b0c8cb
fix config bugs; add l1 test; update readme
zpqiu Oct 11, 2025
1503f41
ci: change teacher model name
zpqiu Oct 12, 2025
3a08782
remove megatron assert
zpqiu Oct 12, 2025
2773293
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 12, 2025
f60c4cd
Update nemo_rl/models/policy/megatron_policy_worker.py
zpqiu Oct 12, 2025
7dc9b03
remove redundant code
zpqiu Oct 15, 2025
18ab6a1
support multi-epoch; update distillation config; adjust policy initil…
zpqiu Oct 16, 2025
03fb2f9
align with grpo
zpqiu Oct 16, 2025
4b104c7
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 16, 2025
266bf29
resolve conflict
zpqiu Oct 16, 2025
0cc8fdc
Update nemo_rl/algorithms/distillation.py
zpqiu Oct 16, 2025
1e4099e
correct calculation of total iter num
zpqiu Oct 16, 2025
208ca14
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 16, 2025
a55958f
cleanup compatibility code
zpqiu Oct 17, 2025
9e9cb0d
fix typo
zpqiu Oct 17, 2025
17f2f58
Update distillation.py
zpqiu Oct 17, 2025
beebd39
fix missing epoch config
zpqiu Oct 17, 2025
74232a2
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 17, 2025
20f719f
Merge branch 'main' into feat-distillation-mcore
zpqiu Oct 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,6 @@ uv run python examples/run_grpo_sliding_puzzle.py

We provide an example on-policy distillation experiment using the [DeepScaler dataset](https://huggingface.co/agentica-org/DeepScaleR-1.5B-Preview).

> [!NOTE]
> Distillation currently supports the DTensor and vLLM generation backend. Megatron generation/training paths are not supported yet.

### On-policy Distillation Single Node

To run on-policy distillation on a single GPU using `Qwen/Qwen3-1.7B-Base` as the student and `Qwen/Qwen3-4B` as the teacher:
Expand Down
68 changes: 67 additions & 1 deletion examples/configs/distillation_math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ distillation:
num_generations_per_prompt: 1
max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question)
max_num_steps: 1000
max_num_epochs: 10
val_batch_size: 64
val_period: 20
val_at_start: false
Expand Down Expand Up @@ -80,8 +81,73 @@ policy: &POLICY_BASE
foreach: False
fused: False

megatron_cfg: # [TODO]
megatron_cfg: &MEGATRON_BASE
enabled: false
empty_unused_memory_level: 0
activation_checkpointing: false
converter_type: "Qwen3ForCausalLM"
tensor_model_parallel_size: 2
expert_tensor_parallel_size: 1
expert_model_parallel_size: 1
pipeline_model_parallel_size: 2
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
context_parallel_size: 2
pipeline_dtype: ${policy.precision}
sequence_parallel: false
freeze_moe_router: true
moe_router_dtype: "fp64"
moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
moe_permute_fusion: false
#gives ~20% training perf speedup with sequence packing
apply_rope_fusion: True
bias_activation_fusion: True
defer_fp32_logits: null

optimizer:
optimizer: "adam"
lr: 2.00001e-5
min_lr: 2.0e-5
weight_decay: 0.01
bf16: true
fp16: false
params_dtype: "float32"

#adam
adam_beta1: 0.9
adam_beta2: 0.999
adam_eps: 1e-8

#sgd
sgd_momentum: 0.9

#distributed optimizer
use_distributed_optimizer: true
use_precision_aware_optimizer: true

# optimizer cpu offload
optimizer_cpu_offload: false
optimizer_offload_fraction: 0.0

clip_grad: ${policy.max_grad_norm}

scheduler:
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: 1000
lr_warmup_iters: 10
lr_warmup_init: 2.0e-6

distributed_data_parallel_config:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: "optim_grads_params"

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
Expand Down
158 changes: 158 additions & 0 deletions examples/configs/distillation_math_megatron.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
defaults: distillation_math.yaml

checkpointing:
checkpoint_dir: "checkpoints/distillation-megatron-${policy.model_name}"

policy: &POLICY_BASE
model_name: "Qwen/Qwen3-1.7B-Base"
tokenizer:
name: ${..model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 64
train_micro_batch_size: 1
generation_batch_size: 64
logprob_batch_size: 1
max_total_sequence_length: 8192
precision: "bfloat16"
logprob_chunk_size: null

dtensor_cfg:
enabled: false

dynamic_batching:
enabled: false
train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: true
train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

max_grad_norm: 1.0

make_sequence_length_divisible_by: ${mul:${mul:${.megatron_cfg.tensor_model_parallel_size}, ${.megatron_cfg.context_parallel_size}}, 2}

megatron_cfg: &MEGATRON_BASE
enabled: true
empty_unused_memory_level: 0
activation_checkpointing: false
converter_type: "Qwen3ForCausalLM"
tensor_model_parallel_size: 2
expert_tensor_parallel_size: 1
expert_model_parallel_size: 1
pipeline_model_parallel_size: 2
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
context_parallel_size: 2
pipeline_dtype: ${policy.precision}
sequence_parallel: false
freeze_moe_router: true
moe_router_dtype: "fp64"
moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
moe_permute_fusion: false
#gives ~20% training perf speedup with sequence packing
apply_rope_fusion: True
bias_activation_fusion: True
defer_fp32_logits: null

optimizer:
optimizer: "adam"
lr: 2.00001e-5
min_lr: 2.0e-5
weight_decay: 0.01
bf16: true
fp16: false
params_dtype: "float32"

#adam
adam_beta1: 0.9
adam_beta2: 0.999
adam_eps: 1e-8

#sgd
sgd_momentum: 0.9

#distributed optimizer
use_distributed_optimizer: true
use_precision_aware_optimizer: true

# optimizer cpu offload
optimizer_cpu_offload: false
optimizer_offload_fraction: 0.0

clip_grad: ${policy.max_grad_norm}

scheduler:
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: 1000
lr_warmup_iters: 10
lr_warmup_init: 2.0e-6

distributed_data_parallel_config:
grad_reduce_in_fp32: false
overlap_grad_reduce: true
overlap_param_gather: true
average_in_collective: true
use_custom_fsdp: false
data_parallel_sharding_strategy: "optim_grads_params"

generation:
backend: "vllm"
max_new_tokens: ${..max_total_sequence_length} # refer to local policy/teacher config
temperature: 1.0
top_p: 1.0
top_k: null
stop_token_ids: null
stop_strings: null
vllm_cfg:
async_engine: false
precision: ${...precision}
tensor_parallel_size: 1
pipeline_parallel_size: 1
expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
gpu_memory_utilization: 0.6
max_model_len: ${...max_total_sequence_length} # refer to local policy/teacher config
enforce_eager: False
use_deep_gemm: False
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
distributed_executor_backend: null

colocated:
# true: generation shares training GPUs
# false: uses dedicated generation resources
enabled: true
# only relevant when enabled is false
resources:
gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1
num_nodes: null # Decides number of nodes to be dedicated to generation

teacher:
<<: *POLICY_BASE
model_name: "Qwen/Qwen3-4B"
megatron_cfg:
<<: *MEGATRON_BASE
context_parallel_size: 2
tensor_model_parallel_size: 2
pipeline_model_parallel_size: 2

logger:
wandb_enabled: true
wandb:
project: "nemo-distillation"
name: "distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
tensorboard:
log_dir: "tb_logs-distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
mlflow:
run_name: "distillation-math-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"

cluster:
gpus_per_node: 8
num_nodes: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
defaults: ../../distillation_math.yaml
distillation:
num_prompts_per_step: 32
max_num_steps: 20
val_batch_size: 32
val_period: 10
max_val_samples: 256
loss_fn:
kl_type: reverse
checkpointing:
checkpoint_dir: checkpoints/distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack
policy:
train_global_batch_size: 32
generation_batch_size: 32
dtensor_cfg:
enabled: false
dynamic_batching:
enabled: false
sequence_packing:
enabled: true
make_sequence_length_divisible_by: ${mul:${mul:${.megatron_cfg.tensor_model_parallel_size},
${.megatron_cfg.context_parallel_size}}, 2}
megatron_cfg:
enabled: true
teacher:
model_name: Qwen/Qwen3-32B
dtensor_cfg:
enabled: false
dynamic_batching:
enabled: false
sequence_packing:
enabled: true
megatron_cfg:
enabled: true
tensor_model_parallel_size: 4
context_parallel_size: 1
logger:
log_dir: logs/distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack
wandb:
project: nemo-rl
name: distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack
Loading
Loading