diff --git a/README.md b/README.md index 62d1706c48..e37685c828 100644 --- a/README.md +++ b/README.md @@ -327,9 +327,6 @@ uv run python examples/run_grpo_sliding_puzzle.py We provide an example on-policy distillation experiment using the [DeepScaler dataset](https://huggingface.co/agentica-org/DeepScaleR-1.5B-Preview). -> [!NOTE] -> Distillation currently supports the DTensor and vLLM generation backend. Megatron generation/training paths are not supported yet. - ### On-policy Distillation Single Node To run on-policy distillation on a single GPU using `Qwen/Qwen3-1.7B-Base` as the student and `Qwen/Qwen3-4B` as the teacher: diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index d007e78be6..e0b8bcf283 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -4,6 +4,7 @@ distillation: num_generations_per_prompt: 1 max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) max_num_steps: 1000 + max_num_epochs: 10 val_batch_size: 64 val_period: 20 val_at_start: false @@ -80,8 +81,73 @@ policy: &POLICY_BASE foreach: False fused: False - megatron_cfg: # [TODO] + megatron_cfg: &MEGATRON_BASE enabled: false + empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: "Qwen3ForCausalLM" + tensor_model_parallel_size: 2 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 2 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 2 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + bias_activation_fusion: True + defer_fp32_logits: null + + optimizer: + optimizer: "adam" + lr: 2.00001e-5 + min_lr: 2.0e-5 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 10 + lr_warmup_init: 2.0e-6 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/distillation_math_megatron.yaml b/examples/configs/distillation_math_megatron.yaml new file mode 100644 index 0000000000..3df59eba84 --- /dev/null +++ b/examples/configs/distillation_math_megatron.yaml @@ -0,0 +1,158 @@ +defaults: distillation_math.yaml + +checkpointing: + checkpoint_dir: "checkpoints/distillation-megatron-${policy.model_name}" + +policy: &POLICY_BASE + model_name: "Qwen/Qwen3-1.7B-Base" + tokenizer: + name: ${..model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 64 + train_micro_batch_size: 1 + generation_batch_size: 64 + logprob_batch_size: 1 + max_total_sequence_length: 8192 + precision: "bfloat16" + logprob_chunk_size: null + + dtensor_cfg: + enabled: false + + dynamic_batching: + enabled: false + train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}} + logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: true + train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}} + logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + max_grad_norm: 1.0 + + make_sequence_length_divisible_by: ${mul:${mul:${.megatron_cfg.tensor_model_parallel_size}, ${.megatron_cfg.context_parallel_size}}, 2} + + megatron_cfg: &MEGATRON_BASE + enabled: true + empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: "Qwen3ForCausalLM" + tensor_model_parallel_size: 2 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 2 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 2 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + bias_activation_fusion: True + defer_fp32_logits: null + + optimizer: + optimizer: "adam" + lr: 2.00001e-5 + min_lr: 2.0e-5 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 10 + lr_warmup_init: 2.0e-6 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + generation: + backend: "vllm" + max_new_tokens: ${..max_total_sequence_length} # refer to local policy/teacher config + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${...precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP + gpu_memory_utilization: 0.6 + max_model_len: ${...max_total_sequence_length} # refer to local policy/teacher config + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +teacher: + <<: *POLICY_BASE + model_name: "Qwen/Qwen3-4B" + megatron_cfg: + <<: *MEGATRON_BASE + context_parallel_size: 2 + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 2 + +logger: + wandb_enabled: true + wandb: + project: "nemo-distillation" + name: "distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + tensorboard: + log_dir: "tb_logs-distillation-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + mlflow: + run_name: "distillation-math-megatron-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}" + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.yaml b/examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.yaml new file mode 100644 index 0000000000..6fda3fe24e --- /dev/null +++ b/examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.yaml @@ -0,0 +1,41 @@ +defaults: ../../distillation_math.yaml +distillation: + num_prompts_per_step: 32 + max_num_steps: 20 + val_batch_size: 32 + val_period: 10 + max_val_samples: 256 +loss_fn: + kl_type: reverse +checkpointing: + checkpoint_dir: checkpoints/distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack +policy: + train_global_batch_size: 32 + generation_batch_size: 32 + dtensor_cfg: + enabled: false + dynamic_batching: + enabled: false + sequence_packing: + enabled: true + make_sequence_length_divisible_by: ${mul:${mul:${.megatron_cfg.tensor_model_parallel_size}, + ${.megatron_cfg.context_parallel_size}}, 2} + megatron_cfg: + enabled: true +teacher: + model_name: Qwen/Qwen3-32B + dtensor_cfg: + enabled: false + dynamic_batching: + enabled: false + sequence_packing: + enabled: true + megatron_cfg: + enabled: true + tensor_model_parallel_size: 4 + context_parallel_size: 1 +logger: + log_dir: logs/distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack + wandb: + project: nemo-rl + name: distillation-qwen3-32b-to-1.7b-base-megatron-tp2pp2cp2-pack diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index 7f9201f91e..dada437bae 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -75,7 +75,8 @@ class DistillationConfig(TypedDict): num_prompts_per_step: int num_generations_per_prompt: int max_rollout_turns: int # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) - max_num_steps: int + max_num_steps: int # maximum number of steps to train for + max_num_epochs: int # maximum number of epochs to train for val_batch_size: int val_period: int val_at_start: bool @@ -85,7 +86,9 @@ class DistillationConfig(TypedDict): class DistillationSaveState(TypedDict): - step: int + total_steps: int # Track total number of steps across all epochs + current_epoch: int # Track current epoch + current_step: int # Track step within current epoch val_reward: NotRequired[ float ] # Can be any metric. Setted to 'accuracy' by default in validation. @@ -95,7 +98,9 @@ class DistillationSaveState(TypedDict): def _default_distillation_save_state() -> DistillationSaveState: return { - "step": 0, + "current_epoch": 0, + "current_step": 0, + "total_steps": 0, "val_reward": -99999999.0, # Aligned with GRPO "consumed_samples": 0, "total_valid_tokens": 0, @@ -185,17 +190,8 @@ def setup( "A generation config in the PolicyConfig is required for distillation" ) - # Disallow Megatron paths (generation/training) and SP + packing for distillation - assert generation_config["backend"] != "megatron", ( - "Distillation does not support Megatron generation backend; please use vLLM." - ) + # Disallow SP + packing for dtensor path for cfg, who in ((policy_config, "student"), (teacher_config, "teacher")): - if "megatron_cfg" in cfg and cfg["megatron_cfg"]["enabled"]: - raise AssertionError( - f"Distillation does not support Megatron training path ({who} policy). " - "Please refer to https://github.com/NVIDIA-NeMo/RL/issues/1151 for more details." - ) - # DTensor sequence parallel is supported; ensure CP and SP are not enabled together # This incompatibility is enforced in DTensor workers during initialization. # Additionally, SP may not be compatible with sequence packing for some models. @@ -254,7 +250,9 @@ def setup( ) dataloader.load_state_dict(dataloader_state_dict) - print(f" ✓ Training dataloader loaded with {len(train_dataset)} samples") + print( + f" ✓ Training dataloader loaded with {len(train_dataset)} samples", flush=True + ) # Load validation dataset if provided val_dataloader: Optional[StatefulDataLoader] = None @@ -269,12 +267,15 @@ def setup( shuffle=False, collate_fn=rl_collate_fn, ) - print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples") + print( + f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) # ========================== # Cluster # ========================== - print("\n▶ Setting up compute cluster...") + print("\n▶ Setting up compute cluster...", flush=True) colocated_inference = generation_config["colocated"]["enabled"] if colocated_inference: @@ -290,9 +291,15 @@ def setup( ) train_cluster = cluster inference_cluster = cluster - print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + print( + f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes", + flush=True, + ) else: - # We has disallow megatron path for distillation above. + assert generation_config["backend"] != "megatron", ( + "Non-colocated inference is not supported for Megatron generation backends. " + "Please use vLLM backend for generation." + ) # train resources will be updated through overall and inference resources below train_gpus_per_node = cluster_config["gpus_per_node"] @@ -350,37 +357,14 @@ def setup( max_colocated_worker_groups=3, ) print( - f" ✓ Separate clusters created: train={train_nodes}x{train_gpus_per_node}GPUs, inference={inference_nodes}x{inference_gpus_per_node}GPUs" + f" ✓ Separate clusters created: train={train_nodes}x{train_gpus_per_node}GPUs, inference={inference_nodes}x{inference_gpus_per_node}GPUs", + flush=True, ) - # ========================== - # Student Policy - # ========================== - print("\n▶ Setting up student policy...") - - # Checkpoint paths - if last_checkpoint_path: - weights_path = Path(last_checkpoint_path) / "policy" / "weights" - optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" - else: - weights_path = None - optimizer_path = None - - student_policy = Policy( - name_prefix="student", - cluster=train_cluster, - config=policy_config, - tokenizer=tokenizer, - weights_path=weights_path, - optimizer_path=optimizer_path, - init_optimizer=True, - init_reference_model=False, - ) - # ========================== # Teacher Policy # ========================== - print("\n▶ Setting up teacher policy...") + print("\n▶ Setting up teacher policy...", flush=True) # Checkpoint paths weights_path = None optimizer_path = None @@ -390,6 +374,14 @@ def setup( tokenizer, policy_config["model_name"], teacher_config["model_name"] ) + if "megatron_cfg" in teacher_config and teacher_config["megatron_cfg"]["enabled"]: + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(dataloader), + ) + teacher_config["megatron_cfg"]["train_iters"] = total_train_iters + teacher_policy = Policy( name_prefix="teacher", cluster=train_cluster, @@ -400,9 +392,10 @@ def setup( init_optimizer=False, init_reference_model=False, ) + teacher_policy.offload_after_refit() # ========================== - # Generation Interface + # Student Generation Interface # ========================== backend = generation_config["backend"] generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM @@ -421,8 +414,41 @@ def setup( ) student_generation.finish_generation() print( - f" ✓ Using vLLM backend for generation with {policy_config['model_name']}" + f" ✓ Using vLLM backend for generation with {policy_config['model_name']}", + flush=True, + ) + + # ========================== + # Student Policy + # ========================== + print("\n▶ Setting up student policy...", flush=True) + + # Checkpoint paths + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + else: + weights_path = None + optimizer_path = None + + if "megatron_cfg" in policy_config and policy_config["megatron_cfg"]["enabled"]: + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(dataloader), ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + student_policy = Policy( + name_prefix="student", + cluster=train_cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + init_reference_model=False, + ) if student_generation is not None: state_dict_info = student_policy.prepare_refit_info() @@ -449,7 +475,7 @@ def setup( print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") - print("=" * 60 + "\n") + print("=" * 60 + "\n", flush=True) return ( student_policy, @@ -501,19 +527,29 @@ def distillation_train( POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running assert student_generation is not None # for mypy type check - # common config/state itmes - step = distillation_save_state["step"] + # common config/state items + current_epoch = distillation_save_state["current_epoch"] # current epoch + current_step = distillation_save_state[ + "current_step" + ] # current step within current epoch + total_steps = distillation_save_state[ + "total_steps" + ] # total number of steps across all epochs consumed_samples = distillation_save_state["consumed_samples"] - total_valid_tokens = distillation_save_state.get( - "total_valid_tokens", 0 - ) # Default to 0 for backward compatibility with older checkpoints + total_valid_tokens = distillation_save_state["total_valid_tokens"] val_period = master_config["distillation"]["val_period"] val_at_start = master_config["distillation"]["val_at_start"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + max_epochs = master_config["distillation"][ + "max_num_epochs" + ] # max number of epochs to train for + max_steps = master_config["distillation"][ + "max_num_steps" + ] # max number of steps to train for # Run validation at the start if configured - if val_at_start and step == 0: - print("\n🔍 Running initial validation...") + if val_at_start and total_steps == 0: + print("\n🔍 Running initial validation...", flush=True) if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( student_policy, student_generation, colocated_inference @@ -526,28 +562,35 @@ def distillation_train( val_dataloader, tokenizer, val_task_to_env, - step=0, + step=total_steps, master_config=master_config, ) student_generation.finish_generation() - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") - # Run distillation training (multi-epoch until reaching max_num_steps) + # Run distillation training (multi-epoch until reaching max_num_steps or max_num_epochs) batch: BatchedDataDict[DatumSpec] - max_steps = master_config["distillation"]["max_num_steps"] - while step < max_steps: + while total_steps < max_steps and current_epoch < max_epochs: + print( + f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_epochs} {'=' * 25}", + flush=True, + ) + for batch in dataloader: - print(f"\n{'=' * 25} Step {step + 1}/{max_steps} {'=' * 25}") - maybe_gpu_profile_step(student_policy, step + 1) + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_steps)} {'=' * 25}", + flush=True, + ) + maybe_gpu_profile_step(student_policy, total_steps + 1) if student_policy != student_generation: - maybe_gpu_profile_step(student_generation, step + 1) + maybe_gpu_profile_step(student_generation, total_steps + 1) val_metrics, validation_timings = None, None with timer.time("total_step_time"): # Prepare batch - print("▶ Preparing batch...") + print("▶ Preparing batch...", flush=True) with timer.time("data_processing"): # Repeat batch items repeated_batch: BatchedDataDict[DatumSpec] = ( @@ -558,7 +601,8 @@ def distillation_train( # Generate responses - this updates the LLMMessageLogType in repeated_batch print( - f"▶ Generating responses for batch of size {repeated_batch.size}..." + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, ) with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: @@ -644,11 +688,11 @@ def distillation_train( ) train_data.to("cpu") - print("▶ Preparing for teacher logprob inference...") + print("▶ Preparing for teacher logprob inference...", flush=True) with timer.time("teacher_logprob_inference_prep"): teacher_policy.prepare_for_lp_inference() - print("▶ Computing teacher logprobs...") + print("▶ Computing teacher logprobs...", flush=True) with timer.time("teacher_logprob_inference"): teacher_topk = teacher_policy.get_topk_logits( train_data, k=master_config["distillation"]["topk_logits_k"] @@ -656,22 +700,23 @@ def distillation_train( train_data["teacher_topk_logits"] = teacher_topk["topk_logits"] train_data["teacher_topk_indices"] = teacher_topk["topk_indices"] - print("▶ Preparing for training...") + print("▶ Preparing for training...", flush=True) with timer.time("training_prep"): teacher_policy.offload_after_refit() student_policy.prepare_for_training() # set model train and reload optim to GPU POLICY_GENERATION_STALE = True - print("▶ Training policy...") + print("▶ Training policy...", flush=True) with timer.time("policy_training"): train_results = student_policy.train(train_data, loss_fn) - is_last_step = ( - step + 1 == master_config["distillation"]["max_num_steps"] + is_last_step = (total_steps + 1 >= max_steps) or ( + (current_epoch + 1 == max_epochs) + and (current_step + 1 == len(dataloader)) ) # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: + if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( student_policy, student_generation, colocated_inference @@ -684,14 +729,16 @@ def distillation_train( val_dataloader, tokenizer, val_task_to_env, - step=step + 1, + step=total_steps + 1, master_config=master_config, ) student_generation.finish_generation() logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") metrics = { "loss": train_results["loss"].numpy(), @@ -722,9 +769,10 @@ def distillation_train( should_save_by_step = ( is_last_step - or (step + 1) % master_config["checkpointing"]["save_period"] == 0 + or (total_steps + 1) % master_config["checkpointing"]["save_period"] + == 0 ) - # +1 because step is 0-indexed + # +1 because total_steps is 0-indexed # Check if timeout-based checkpointing is enabled in config. should_save_by_timeout = timeout.check_save() @@ -733,7 +781,9 @@ def distillation_train( ): student_policy.prepare_for_training() - distillation_save_state["step"] = step + 1 + distillation_save_state["current_epoch"] = current_epoch + distillation_save_state["current_step"] = current_step + 1 + distillation_save_state["total_steps"] = total_steps + 1 distillation_save_state["total_valid_tokens"] = total_valid_tokens if val_metrics is not None: distillation_save_state["val_reward"] = val_metrics["accuracy"] @@ -754,9 +804,12 @@ def distillation_train( master_config["checkpointing"]["metric_name"] = None with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") + print( + f"Saving checkpoint for step {total_steps + 1}...", + flush=True, + ) checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, distillation_save_state, master_config + total_steps + 1, distillation_save_state, master_config ) student_policy.save_checkpoint( weights_path=os.path.join( @@ -780,7 +833,9 @@ def distillation_train( # Log training data log_data = {"content": flat_messages["content"]} log_data["input_lengths"] = input_lengths.tolist() - logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl") + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps}.jsonl" + ) timing_metrics: dict[str, float] = timer.get_timing_metrics( reduction_op="sum" @@ -800,16 +855,18 @@ def distillation_train( ) num_ranks = train_results["num_ranks"] print( - f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", + flush=True, ) if "theoretical_tflops" in train_results: theoretical_tflops = train_results["theoretical_tflops"] print( - f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + flush=True, ) metrics["train_fp_utilization"] = total_tflops / theoretical_tflops - print("\n⏱️ Timing:") + print("\n⏱️ Timing:", flush=True) # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) @@ -825,7 +882,7 @@ def distillation_train( } ) - print(f" • Total step time: {total_time:.2f}s") + print(f" • Total step time: {total_time:.2f}s", flush=True) # Display all other timing metrics for k, v in sorted( @@ -833,26 +890,31 @@ def distillation_train( ): if k != "total_step_time": percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) timing_metrics["valid_tokens_per_sec_per_gpu"] = ( metrics["global_valid_toks"] / total_time / total_num_gpus ) - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") timer.reset() - step += 1 + current_step += 1 + total_steps += 1 if should_save_by_timeout: print("Timeout has been reached, stopping training early", flush=True) return - if step >= max_steps: + if total_steps >= max_steps: print( "Max number of steps has been reached, stopping training early", flush=True, ) return + # End of epoch + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch + def validate( policy_generation: GenerationInterface, @@ -864,18 +926,19 @@ def validate( ) -> tuple[dict[str, Any], dict[str, Any]]: """Run validation on the validation dataset.""" if val_dataloader is None: - print(" ⚠️ No validation dataloader provided, skipping validation") + print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) return {}, {} if val_task_to_env is None: print( - " ⚠️ No validation task to environment mapping provided, skipping validation" + " ⚠️ No validation task to environment mapping provided, skipping validation", + flush=True, ) return {}, {} timer = Timer() with timer.time("total_validation_time"): - print(f"▶ Starting validation at step {step}...") + print(f"▶ Starting validation at step {step}...", flush=True) total_rewards = [] # Can be any metric. Setted to 'accuracy' by default. total_lengths = [] @@ -956,7 +1019,7 @@ def validate( ) except Exception as e: print(f"\n ⚠️ Error displaying message samples: {str(e)}") - print(" ⚠️ Continuing validation without displaying samples...") + print(" ⚠️ Continuing validation without displaying samples...", flush=True) # Get timing metrics timing_metrics = timer.get_timing_metrics(reduction_op="sum") @@ -966,12 +1029,12 @@ def validate( print("\n📊 Validation Results:") print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") - print(f" • Samples processed: {len(total_rewards)}") + print(f" • Samples processed: {len(total_rewards)}", flush=True) # Print timing information print("\n ⏱️ Validation Timing:") validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s") + print(f" • Total validation time: {validation_time:.2f}s", flush=True) # Make sure to reset the timer after validation timer.reset() diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index a2da9a3981..2686e9c53a 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -100,6 +100,8 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( + allgather_cp_sharded_tensor, + distributed_vocab_topk, from_parallel_logits_to_logprobs, from_parallel_logits_to_logprobs_packed_sequences, ) @@ -1422,9 +1424,289 @@ def get_topk_logits( k: int, micro_batch_size: Optional[int] = None, ): - raise NotImplementedError( - "get_topk_logits (teacher top-k logits for distillation) is not implemented for the Megatron backend yet." - " Track progress in the GitHub issue: https://github.com/NVIDIA-NeMo/RL/issues/1151" + """Get the top-k logits and indices for a batch of data. + + The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence. + + Returns: + BatchedDataDict containing: + - topk_logits: Tensor of top-k logits for each position in the sequence + - topk_indices: Tensor of top-k indices for each position in the sequence + """ + no_grad = torch.no_grad() + no_grad.__enter__() + + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + + sequence_dim = 1 + input_seq_dim_size = data["input_ids"].shape[sequence_dim] + # Avoid shadowing the function argument `k` by using a distinct variable name + for tensor_name, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == input_seq_dim_size, ( + f"Tensor {tensor_name} must have sequence dimension {sequence_dim} of size {input_seq_dim_size}, but got shape {v.shape}" + ) + + self.model.eval() + + pp_seq_dim_size = input_seq_dim_size + pp_grp = get_pipeline_model_parallel_group() + + # If using sequence packing with PP>1, pad full sequence to static PP buffer length + pad_full_seq_to = None + if ( + self.cfg["sequence_packing"]["enabled"] + and self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1 + ): + _, pad_full_seq_to = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + pp_seq_dim_size = pad_full_seq_to + + def forward_step_fn( + data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel + ): + nonlocal pad_full_seq_to + data_dict = next(data_iterator).to("cuda") + + pack = self.cfg["sequence_packing"]["enabled"] + if pack: + original_seq_length = data_dict["input_ids"].shape[1] + tp_size = self.cfg["megatron_cfg"]["tensor_model_parallel_size"] + pp_size = self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + cp_rank = get_context_parallel_rank() + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False): + pad_factor = math.lcm(16, pad_factor) + + ( + input_ids_unpacked, + input_ids_cp_sharded, + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + data_dict["input_ids"].clone(), + data_dict["input_lengths"], + pad_individual_seqs_to_multiple_of=pad_factor, + pad_packed_seq_to=pad_full_seq_to, + cp_rank=cp_rank, + cp_size=cp_size, + ) + attention_mask, position_ids = None, None + seq_lengths = data_dict["input_lengths"] + unpacked_seqlen = original_seq_length + else: + input_ids_cp_sharded = data_dict["input_ids"] + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + data=input_ids_cp_sharded, + eod_token=0, + pad_token=0, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + pad_mask_loss=False, + ) + packed_seq_params = None + + multimodal_data = data_dict.get_multimodal_dict( + as_tensors=True, device=input_ids_cp_sharded.device + ) + if len(multimodal_data) > 0: + position_ids = None + + output_tensor = model( + input_ids=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + **multimodal_data, + ) + + if "generation" in self.cfg and self.cfg["generation"] is not None: + output_tensor.div_(self.cfg["generation"]["temperature"]) + + def collection_fn(_): + # Only the last PP stage produces final logits/top-k; earlier stages return empty + # if not is_pipeline_last_stage(ignore_virtual=True): + # return output_tensor.new_zeros(()), {} + + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + vocab_shard_size = output_tensor.shape[-1] + vocab_start_index = tp_rank * vocab_shard_size + + chunk_size = None + if "logprob_chunk_size" in self.cfg: + chunk_size = self.cfg["logprob_chunk_size"] + + topk_vals_local, topk_idx_local = distributed_vocab_topk( + output_tensor, + k, + tp_grp, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_start_index + vocab_shard_size, + chunk_size=chunk_size, + ) + + if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: + cp_grp = get_context_parallel_group() + if pack: + # Per-sequence CP allgather following packed-sequence logic + batch_size = data_dict["input_ids"].shape[0] + total_packed_len = int(cu_seqlens_padded[-1].item()) + + topk_vals_full = torch.zeros( + (1, total_packed_len, k), + dtype=topk_vals_local.dtype, + device=topk_vals_local.device, + ) + topk_idx_full = torch.zeros( + (1, total_packed_len, k), + dtype=topk_idx_local.dtype, + device=topk_idx_local.device, + ) + + for i in range(batch_size): + start_idx = int(cu_seqlens_padded[i].item()) + end_idx = int(cu_seqlens_padded[i + 1].item()) + if end_idx > start_idx: + local_vals_slice = topk_vals_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + local_idx_slice = topk_idx_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + gathered_vals = allgather_cp_sharded_tensor( + local_vals_slice, cp_grp, seq_dim=1 + ) + gathered_idx = allgather_cp_sharded_tensor( + local_idx_slice, cp_grp, seq_dim=1 + ) + # Some kernels may return [X, Y, k] where X*Y = (end_idx - start_idx). + # Flatten leading dims and reshape to [1, expected_len, k] to match target. + expected_len = end_idx - start_idx + if ( + gathered_vals.dim() == 3 + and gathered_vals.shape[1] != expected_len + ): + gathered_vals = gathered_vals.reshape( + 1, expected_len, gathered_vals.shape[-1] + ) + if ( + gathered_idx.dim() == 3 + and gathered_idx.shape[1] != expected_len + ): + gathered_idx = gathered_idx.reshape( + 1, expected_len, gathered_idx.shape[-1] + ) + topk_vals_full[:, start_idx:end_idx, :] = gathered_vals + topk_idx_full[:, start_idx:end_idx, :] = gathered_idx + else: + # Sequence packing must be enabled when CP > 1 + raise RuntimeError( + "Context Parallelism (CP>1) requires sequence packing to be enabled." + ) + else: + topk_vals_full = topk_vals_local + topk_idx_full = topk_idx_local + + if pack: + batch_size = data_dict["input_ids"].shape[0] + out_vals = torch.zeros( + (batch_size, unpacked_seqlen, k), + dtype=topk_vals_full.dtype, + device=topk_vals_full.device, + ) + out_idx = torch.zeros( + (batch_size, unpacked_seqlen, k), + dtype=topk_idx_full.dtype, + device=topk_idx_full.device, + ) + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + start_idx = int(cu_seqlens_padded[i].item()) + if seq_len > 0: + out_vals[i, :seq_len, :] = topk_vals_full[ + 0, start_idx : start_idx + seq_len, : + ] + out_idx[i, :seq_len, :] = topk_idx_full[ + 0, start_idx : start_idx + seq_len, : + ] + return output_tensor.new_zeros(()), { + "topk_logits": out_vals, + "topk_indices": out_idx, + } + else: + return output_tensor.new_zeros(()), { + "topk_logits": topk_vals_full, + "topk_indices": topk_idx_full, + } + + return output_tensor, collection_fn + + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + micro_batch = logprob_batch_size + elif self.cfg["sequence_packing"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + data_iterator_len, _ = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + micro_batch = 1 + else: + mb_iterator = data.make_microbatch_iterator(logprob_batch_size) + data_iterator_len = max(1, data.size // logprob_batch_size) + micro_batch = logprob_batch_size + + forward_backward_func = get_forward_backward_func() + list_of_outputs = forward_backward_func( + forward_step_func=forward_step_fn, + data_iterator=mb_iterator, + model=self.model, + num_microbatches=data_iterator_len, + seq_length=pp_seq_dim_size, + micro_batch_size=micro_batch, + decoder_seq_length=pp_seq_dim_size, + forward_only=True, + ) + + if is_pipeline_last_stage(ignore_virtual=True): + logits_chunks = [] + indices_chunks = [] + for out in list_of_outputs: + tk = out["topk_logits"] + ti = out["topk_indices"] + pad_len = input_seq_dim_size - tk.shape[1] + if pad_len > 0: + tk = torch.nn.functional.pad(tk, (0, 0, 0, pad_len), value=0.0) + ti = torch.nn.functional.pad(ti, (0, 0, 0, pad_len), value=0) + logits_chunks.append(tk) + indices_chunks.append(ti) + + topk_logits = torch.cat(logits_chunks, dim=0) + topk_indices = torch.cat(indices_chunks, dim=0) + + topk_logits = broadcast_tensor( + topk_logits, torch.distributed.get_rank(), pp_grp + ) + topk_indices = broadcast_tensor( + topk_indices, torch.distributed.get_rank(), pp_grp + ) + else: + last_pp_rank = get_pipeline_model_parallel_last_rank() + topk_logits = broadcast_tensor(None, last_pp_rank, pp_grp) + topk_indices = broadcast_tensor(None, last_pp_rank, pp_grp) + + no_grad.__exit__(None, None, None) + return BatchedDataDict.from_batches( + [{"topk_logits": topk_logits.cpu(), "topk_indices": topk_indices.cpu()}] ) @wrap_with_nvtx_name("megatron_policy_worker/generate") diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 8e7952305e..d6c7134483 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -30,6 +30,7 @@ time uv run --no-sync bash ./tests/functional/test_mcore_extra_installed_correct time uv run --no-sync bash ./tests/functional/test_automodel_extra_installed_correctly.sh time uv run --no-sync bash ./tests/functional/vlm_grpo.sh time uv run --no-sync bash ./tests/functional/distillation.sh +time uv run --no-sync bash ./tests/functional/distillation_megatron.sh cd /opt/nemo-rl/tests coverage combine .coverage* diff --git a/tests/functional/distillation_megatron.sh b/tests/functional/distillation_megatron.sh new file mode 100644 index 0000000000..b56ea672fb --- /dev/null +++ b/tests/functional/distillation_megatron.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/distillation_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -euo pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_distillation_math.py \ + --config $PROJECT_ROOT/examples/configs/distillation_math_megatron.yaml \ + policy.model_name=Qwen/Qwen3-0.6B-Base \ + teacher.model_name=Qwen/Qwen3-0.6B \ + cluster.gpus_per_node=2 \ + policy.train_global_batch_size=16 \ + policy.megatron_cfg.tensor_model_parallel_size=1 \ + policy.megatron_cfg.pipeline_model_parallel_size=1 \ + policy.megatron_cfg.context_parallel_size=2 \ + policy.max_total_sequence_length=2048 \ + teacher.megatron_cfg.tensor_model_parallel_size=2 \ + teacher.megatron_cfg.pipeline_model_parallel_size=1 \ + teacher.megatron_cfg.context_parallel_size=1 \ + distillation.max_num_steps=3 \ + distillation.num_prompts_per_step=16 \ + distillation.max_val_samples=16 \ + distillation.val_batch_size=8 \ + distillation.val_period=3 \ + data.dataset_name=OpenMathInstruct-2 \ + loss_fn.zero_outside_topk=false \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=true \ + checkpointing.save_period=3 \ + checkpointing.checkpoint_dir=/tmp/distillation_checkpoints \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["3"] < 1.0' diff --git a/tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh b/tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh new file mode 100755 index 0000000000..6710ac87ce --- /dev/null +++ b/tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh @@ -0,0 +1,42 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=10 +MAX_STEPS=10 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=60 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_distillation_math.py \ + --config $CONFIG_PATH \ + distillation.max_num_steps=$MAX_STEPS \ + distillation.val_period=20 \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl-distillation \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 1.5' \ + 'data["train/loss"]["10"] < 0.5' \ + 'max(data["ray/node.0.gpu.0.mem_gb"]) < 75' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 500' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 06e97ff4eb..4a09c6b92d 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -91,3 +91,6 @@ tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.sh # Distillation tests tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh + +# Short megatron +tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index bc1d4b734e..ec8ed37f63 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -124,6 +124,7 @@ def val_iter(self): master_config = { "distillation": { "max_num_steps": 5, + "max_num_epochs": 10, "val_period": 100, "val_batch_size": 1, "val_at_start": False, @@ -511,6 +512,8 @@ def test_distillation_setup_non_colocated_smoke(monkeypatch): "seed": 42, "topk_logits_k": 64, "num_prompts_per_step": 1, + "max_num_epochs": 10, + "max_num_steps": 100, "val_period": 0, "val_at_start": False, }, @@ -546,6 +549,9 @@ def __init__(self, *args, **kwargs): def prepare_refit_info(self): return {} + def offload_after_refit(self): + return None + def init_collective(self, *args, **kwargs): return [MagicMock()] @@ -613,6 +619,8 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node(): "distillation": { "seed": 42, "topk_logits_k": 64, + "max_num_epochs": 10, + "max_num_steps": 100, "num_prompts_per_step": 1, # Config extraction requires this key "val_period": 0, # Config extraction requires this key "val_at_start": False, # Config extraction requires this key diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 48b2c01dc8..364b08f5f6 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -1321,6 +1321,363 @@ def test_megatron_dpo_training(tiny_llama_model_path): cluster.shutdown() +@pytest.fixture +def topk_setup(request): + """Setup and teardown specifically for top-k logits tests.""" + # Parse parameters: (num_gpus, tp, pp, logprob_chunk_size, defer_fp32_logits, model_fixture_name) + if hasattr(request, "param") and request.param is not None: + ( + num_gpus, + tp, + pp, + logprob_chunk_size, + defer_fp32_logits, + model_fixture_name, + ) = request.param + else: + ( + num_gpus, + tp, + pp, + logprob_chunk_size, + defer_fp32_logits, + model_fixture_name, + ) = (2, 1, 1, None, None, "tiny_llama_model_path") + + # Get the actual model path from the requested fixture + model_name = request.getfixturevalue(model_fixture_name) + + policy = None + cluster = None + data = None + + try: + cluster_name = f"test-megatron-topk-{num_gpus}gpu-tp{tp}-pp{pp}" + print( + f"Creating topk cluster '{cluster_name}' for {num_gpus} GPUs (TP={tp}, PP={pp})" + ) + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + # Determine converter type based on model + converter_type = "LlamaForCausalLM" + if "qwen" in model_name.lower(): + converter_type = "Qwen2ForCausalLM" + elif "gemma" in model_name.lower(): + converter_type = "GemmaForCausalLM" + + config = create_megatron_test_config( + model_name=model_name, + tp=tp, + pp=pp, + converter_type=converter_type, + logprob_chunk_size=logprob_chunk_size, + defer_fp32_logits=defer_fp32_logits, + ) + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) + + print("Creating Megatron topk Policy...") + policy = Policy( + cluster=cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Create test data + print("Creating test batch...") + torch.manual_seed(77) + + input_ids = torch.randint(0, 32000, (4, 64)) # 4 sequences, each of length 64 + attention_mask = torch.ones(4, 64) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + } + ) + + yield policy, cluster, data + + except Exception as e: + print(f"Error during topk setup: {e}") + pytest.skip(f"Topk setup failed: {e}") + finally: + print("Cleaning up topk resources") + if policy: + policy.shutdown() + if cluster: + cluster.shutdown() + + +@pytest.mark.timeout(180) +@pytest.mark.hf_gated +@pytest.mark.parametrize( + "topk_setup", + [ + # (num_gpus, tp, pp, chunk sz, defer fp32, model_fixture_name) + (2, 1, 1, None, None, "tiny_llama_model_path"), + (2, 2, 1, None, None, "tiny_llama_model_path"), + (2, 1, 1, None, None, "tiny_qwen2_model_path"), + (2, 2, 1, None, None, "tiny_qwen2_model_path"), + (2, 1, 1, None, True, "tiny_llama_model_path"), + (2, 2, 1, None, True, "tiny_llama_model_path"), + (2, 1, 1, None, True, "tiny_qwen2_model_path"), + (2, 2, 1, None, True, "tiny_qwen2_model_path"), + (2, 1, 1, 16, True, "tiny_llama_model_path"), + (2, 2, 1, 16, True, "tiny_llama_model_path"), + (2, 1, 1, 16, True, "tiny_qwen2_model_path"), + (2, 2, 1, 16, True, "tiny_qwen2_model_path"), + ], + indirect=True, + ids=[ + "2gpu_dp2_llama", + "2gpu_tp2_llama", + "2gpu_dp2_qwen2", + "2gpu_tp2_qwen2", + "2gpu_dp2_deferfp32_llama", + "2gpu_tp2_deferfp32_llama", + "2gpu_dp2_deferfp32_qwen2", + "2gpu_tp2_deferfp32_qwen2", + "2gpu_dp2_chunked_deferfp32_llama", + "2gpu_tp2_chunked_deferfp32_llama", + "2gpu_dp2_chunked_deferfp32_qwen2", + "2gpu_tp2_chunked_deferfp32_qwen2", + ], +) +def test_megatron_policy_topk_logits(topk_setup): + """Test Megatron policy top-k logits computation.""" + policy, cluster, data = topk_setup + + # Verify resources were created properly + assert policy is not None, "Policy was not created properly" + assert data is not None, "Test data was not created properly" + + # Generate top-k logits + print("\nGenerating top-k logits...") + policy.prepare_for_lp_inference() + k = 5 + outputs = policy.get_topk_logits(data, k=k) + + # Basic validation + assert "topk_logits" in outputs and "topk_indices" in outputs, ( + "Top-k outputs should contain both 'topk_logits' and 'topk_indices'" + ) + topk_logits = outputs["topk_logits"] + topk_indices = outputs["topk_indices"] + + assert isinstance(topk_logits, torch.Tensor) + assert isinstance(topk_indices, torch.Tensor) + assert topk_logits.dtype == torch.float32 + assert topk_indices.dtype in (torch.int32, torch.int64, torch.long) + + # Shape checks + B, S = data.get("input_ids").shape + assert topk_logits.shape == (B, S, k) + assert topk_indices.shape == (B, S, k) + + # Mask invalid positions and check for NaN/Inf + valid_mask = ( + data.get("attention_mask") + .unsqueeze(-1) + .bool() + .expand(-1, -1, topk_logits.shape[-1]) + ) + valid_logits = topk_logits[valid_mask] + assert not torch.isnan(valid_logits).any(), "Top-k logits should not contain NaN" + assert not torch.isinf(valid_logits).any(), "Top-k logits should not contain Inf" + + # Check descending order within top-k for valid positions + if S > 1: + diffs = topk_logits[..., :-1] - topk_logits[..., 1:] + valid_mask_diffs = ( + data.get("attention_mask") + .unsqueeze(-1) + .bool() + .expand(-1, -1, topk_logits.shape[-1] - 1) + ) + diffs = diffs[valid_mask_diffs] + assert (diffs >= -1e-6).all(), "Top-k logits should be non-increasing across k" + + +@pytest.mark.hf_gated +@pytest.mark.timeout(300) +def test_megatron_context_parallel_topk_agreement(tiny_qwen2_model_path): + """Test that CP and non-CP models produce identical top-k logits with sequence packing enabled.""" + num_gpus = 2 + batch_size = 4 + seq_len = 64 + + # Create test data with varying sequence lengths to test sequence packing + torch.manual_seed(123) + input_ids = torch.arange(seq_len * batch_size, device="cuda").reshape( + batch_size, seq_len + ) + input_lengths = torch.tensor([31, 21, 29, 56], dtype=torch.int32) + attention_mask = torch.zeros(batch_size, seq_len) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + } + ) + + k = 5 + + # Test 1: Non-CP model (context_parallel_size=1) with sequence packing + print( + "=== Testing Non-CP model (context_parallel_size=1) with sequence packing for top-k ===" + ) + cluster_no_cp = RayVirtualCluster( + name="test-no-cp-packing-topk", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + config_no_cp = create_megatron_test_config( + tiny_qwen2_model_path, tp=1, pp=1, precision="bfloat16" + ) + # Ensure context parallel is disabled + config_no_cp["megatron_cfg"]["context_parallel_size"] = 1 + + # Enable sequence packing + config_no_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + + tokenizer = get_tokenizer(config_no_cp["tokenizer"]) + config_no_cp["generation"] = configure_generation_config( + config_no_cp["generation"], tokenizer + ) + + policy_no_cp = Policy( + cluster=cluster_no_cp, + config=config_no_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Get top-k from non-CP model with sequence packing + policy_no_cp.prepare_for_lp_inference() + out_no_cp = policy_no_cp.get_topk_logits(data, k=k) + logits_no_cp = out_no_cp["topk_logits"] * attention_mask.unsqueeze(-1) + indices_no_cp = out_no_cp["topk_indices"] + print(f"Non-CP topk logits shape: {logits_no_cp.shape}") + + # Cleanup non-CP resources and run without packing + policy_no_cp.shutdown() + config_no_cp_no_packing = config_no_cp.copy() + config_no_cp_no_packing["sequence_packing"] = {"enabled": False} + policy_no_cp_no_packing = Policy( + cluster=cluster_no_cp, + config=config_no_cp_no_packing, + tokenizer=tokenizer, + init_reference_model=False, + ) + policy_no_cp_no_packing.prepare_for_lp_inference() + out_no_cp_np = policy_no_cp_no_packing.get_topk_logits(data, k=k) + logits_no_cp_np = out_no_cp_np["topk_logits"] * attention_mask.unsqueeze(-1) + indices_no_cp_np = out_no_cp_np["topk_indices"] + print(f"Non-CP (no packing) topk logits shape: {logits_no_cp_np.shape}") + cluster_no_cp.shutdown() + + # Compare non-CP packing vs non-packing + print("=== Comparing non-CP packing vs non-packing top-k ===") + assert logits_no_cp.shape == logits_no_cp_np.shape + assert indices_no_cp.shape == indices_no_cp_np.shape + torch.testing.assert_close(logits_no_cp, logits_no_cp_np, rtol=1e-3, atol=1e-2) + valid_mask = ( + attention_mask.bool().unsqueeze(-1).expand(-1, -1, indices_no_cp.shape[-1]) + ) + assert torch.equal(indices_no_cp[valid_mask], indices_no_cp_np[valid_mask]), ( + "Top-k indices should match between packing and non-packing" + ) + + # Test 2: CP model (context_parallel_size=2) with sequence packing + print( + "=== Testing CP model (context_parallel_size=2) with sequence packing for top-k ===" + ) + cluster_cp = RayVirtualCluster( + name="test-cp-packing-topk", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + + config_cp = create_megatron_test_config( + tiny_qwen2_model_path, tp=1, pp=1, precision="bfloat16" + ) + # Enable context parallel + config_cp["megatron_cfg"]["context_parallel_size"] = 2 + + # Enable sequence packing + config_cp["sequence_packing"] = { + "enabled": True, + "train_mb_tokens": seq_len, + "logprob_mb_tokens": seq_len, + "algorithm": "modified_first_fit_decreasing", + } + config_cp["generation"] = configure_generation_config( + config_cp["generation"], tokenizer + ) + + policy_cp = Policy( + cluster=cluster_cp, + config=config_cp, + tokenizer=tokenizer, + init_reference_model=False, + ) + policy_cp.prepare_for_lp_inference() + out_cp = policy_cp.get_topk_logits(data, k=k) + logits_cp = out_cp["topk_logits"] * attention_mask.unsqueeze(-1) + indices_cp = out_cp["topk_indices"] + + # Cleanup CP resources + policy_cp.shutdown() + cluster_cp.shutdown() + + # Compare CP vs non-CP (no packing) + print("=== Comparing CP vs non-CP (no packing) top-k ===") + assert logits_no_cp_np.shape == logits_cp.shape + assert indices_no_cp_np.shape == indices_cp.shape + assert not torch.isnan(logits_cp).any() + assert not torch.isinf(logits_cp).any() + torch.testing.assert_close(logits_no_cp_np, logits_cp, rtol=1e-3, atol=1e-2) + # since there are close logits, we only check the index match ratio + valid_mask_idx = ( + attention_mask.bool().unsqueeze(-1).expand(-1, -1, indices_cp.shape[-1]) + ) + cp_idx_flat = indices_cp[valid_mask_idx] + nocp_idx_flat = indices_no_cp_np[valid_mask_idx] + match_ratio = (cp_idx_flat == nocp_idx_flat).float().mean().item() + print(f"Top-k index match ratio (CP vs non-CP): {match_ratio:.4f}") + assert match_ratio >= 0.95, ( + f"Top-k index match ratio too low: {match_ratio:.4f} (< 0.95)" + ) + + @pytest.mark.timeout(300) @pytest.mark.hf_gated def test_megatron_sft_training(tiny_llama_model_path): diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 592af19035..48f44b8349 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -154,7 +154,7 @@ def test_all_recipe_yamls_accounted_for_in_test_suites( ) -def test_nightly_compute_stays_below_1030_hours(nightly_test_suite, tracker): +def test_nightly_compute_stays_below_1040_hours(nightly_test_suite, tracker): command = f"DRYRUN=1 HF_HOME=... HF_DATASETS_CACHE=... CONTAINER= ACCOUNT= PARTITION= ./tools/launch {' '.join(nightly_test_suite)}" print(f"Running command: {command}") @@ -186,8 +186,8 @@ def test_nightly_compute_stays_below_1030_hours(nightly_test_suite, tracker): f"Last line of output was not as expected: '{last_line}'" ) total_gpu_hours = float(last_line.split(":")[-1].strip()) - assert total_gpu_hours <= 1030, ( - f"Total GPU hours exceeded 1030: {last_line}. We should revisit the test suites to reduce the total GPU hours." + assert total_gpu_hours <= 1040, ( + f"Total GPU hours exceeded 1040: {last_line}. We should revisit the test suites to reduce the total GPU hours." ) tracker.track("total_nightly_gpu_hours", total_gpu_hours)