From 58c5540ea8945838842814c3b206c18c8be0e78a Mon Sep 17 00:00:00 2001
From: Peter Jin
Date: Thu, 28 Aug 2025 18:07:03 -0700
Subject: [PATCH] Flush stdout.

Signed-off-by: Peter Jin
---
 nemo_rl/algorithms/grpo.py | 86 +++++++++++++++++++++++---------------
 1 file changed, 52 insertions(+), 34 deletions(-)

diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 90454bbab9..83d80bb6ca 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -196,7 +196,7 @@ def setup(
         )
         dataloader.load_state_dict(dataloader_state_dict)
 
-    print(f" ✓ Training dataloader loaded with {len(dataset)} samples")
+    print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)
 
     # Load validation dataset if provided
     val_dataloader: Optional[StatefulDataLoader] = None
@@ -211,12 +211,15 @@
             shuffle=False,
             collate_fn=rl_collate_fn,
         )
-        print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples")
+        print(
+            f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
+            flush=True,
+        )
 
     # ==========================
     # Cluster
     # ==========================
-    print("\n▶ Setting up compute cluster...")
+    print("\n▶ Setting up compute cluster...", flush=True)
 
     colocated_inference = generation_config["colocated"]["enabled"]
     if colocated_inference:
@@ -232,7 +235,10 @@
         )
         train_cluster = cluster
         inference_cluster = cluster
-        print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+        print(
+            f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes",
+            flush=True,
+        )
 
     else:
         assert generation_config["backend"] != "megatron", (
@@ -288,7 +294,8 @@
             max_colocated_worker_groups=1,
         )
         print(
-            f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node"
+            f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
+            flush=True,
         )
 
         # initialize inference cluster
@@ -300,13 +307,14 @@
             max_colocated_worker_groups=1,
         )
         print(
-            f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node"
+            f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
+            flush=True,
         )
 
     # ==========================
     # Training and Inference
     # ==========================
-    print("\n▶ Setting up model and training...")
+    print("\n▶ Setting up model and training...", flush=True)
 
     # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
     backend = generation_config["backend"]
@@ -315,7 +323,8 @@
     if backend == "megatron":
         policy_generation = None
         print(
-            f" ✓ Using {backend} backend for generation with {policy_config['model_name']}"
+            f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
+            flush=True,
        )
     elif backend == "vllm":
         generation_config = cast(VllmConfig, generation_config)
@@ -331,7 +340,8 @@
         # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory).
         policy_generation.finish_generation()
         print(
-            f" ✓ Using vLLM backend for generation with {policy_config['model_name']}"
+            f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
+            flush=True,
         )
 
     if last_checkpoint_path:
@@ -354,7 +364,7 @@
     # if it is not colocated inference, initialize collective communication for update weights
     if not colocated_inference:
         ip, port = train_cluster.get_master_address_and_port()
-        print(f"Using ip: {ip}, port: {port} for collective communication")
+        print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
         # inference cluster + head node of the train cluster
         world_size = inference_nodes * inference_gpus_per_node + 1
         # init collective
@@ -371,7 +381,7 @@
 
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
-    print("=" * 60 + "\n")
+    print("=" * 60 + "\n", flush=True)
 
     return (
         policy,
@@ -445,7 +455,8 @@
         )
         total_num_keys = sum(len(k) for k in grouped_param_keys)
         print(
-            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups",
+            flush=True,
         )
         # do update
         for keys in grouped_param_keys:
@@ -524,7 +535,7 @@
 
     # Run validation at the start if configured
     if val_at_start and step == 0:
-        print("\n🔍 Running initial validation...")
+        print("\n🔍 Running initial validation...", flush=True)
         if NEED_REFIT and POLICY_GENERATION_STALE:
             refit_policy_generation(policy, policy_generation, colocated_inference)
             POLICY_GENERATION_STALE = False
@@ -546,7 +557,8 @@
     batch: BatchedDataDict[DatumSpec]
     for batch in dataloader:
         print(
-            f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
+            f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}",
+            flush=True,
         )
         maybe_gpu_profile_step(policy, step + 1)
         if policy != policy_generation:
@@ -555,7 +567,7 @@
 
         with timer.time("total_step_time"):
             # Prepare batch
-            print("▶ Preparing batch...")
+            print("▶ Preparing batch...", flush=True)
             with timer.time("data_processing"):
                 # Repeat batch items
                 repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
@@ -569,7 +581,10 @@
                 input_ids = batched_flat["token_ids"]
 
             # Generate responses - this updates the LLMMessageLogType in repeated_batch
-            print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
+            print(
+                f"▶ Generating responses for batch of size {repeated_batch.size}...",
+                flush=True,
+            )
             with timer.time("prepare_for_generation"):
                 if NEED_REFIT and POLICY_GENERATION_STALE:
                     refit_policy_generation(
@@ -611,12 +626,12 @@
                     policy_generation.finish_generation()
 
             # Calculate rewards & advantages
-            print("▶ Processing rewards...")
+            print("▶ Processing rewards...", flush=True)
             with timer.time("reward_calculation"):
                 # Extract rewards from final_batch
                 rewards = repeated_batch["total_reward"]
 
-                print("▶ Computing advantages...")
+                print("▶ Computing advantages...", flush=True)
                 baseline, std = calculate_baseline_and_std_per_prompt(
                     input_ids,
                     rewards,
@@ -678,11 +693,11 @@
                 train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
             train_data.to("cpu")
 
-            print("▶ Preparing for logprob inference...")
+            print("▶ Preparing for logprob inference...", flush=True)
             with timer.time("logprob_inference_prep"):
                 policy.prepare_for_lp_inference()
 
-            print("▶ Computing logprobs...")
+            print("▶ Computing logprobs...", flush=True)
print("▶ Computing logprobs...", flush=True) with timer.time("policy_and_reference_logprobs"): fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] reference_logprobs = policy.get_reference_policy_logprobs(train_data)[ @@ -691,12 +706,12 @@ def grpo_train( train_data["prev_logprobs"] = fprop_logprobs train_data["reference_policy_logprobs"] = reference_logprobs - print("▶ Preparing for training...") + print("▶ Preparing for training...", flush=True) with timer.time("training_prep"): policy.prepare_for_training() # set model train and reload optim to GPU POLICY_GENERATION_STALE = True - print("▶ Training policy...") + print("▶ Training policy...", flush=True) with timer.time("policy_training"): train_results = policy.train(train_data, loss_fn) @@ -763,7 +778,7 @@ def grpo_train( master_config["checkpointing"]["metric_name"] = None with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") + print(f"Saving checkpoint for step {step + 1}...", flush=True) checkpoint_path = checkpointer.init_tmp_checkpoint( step + 1, grpo_save_state, master_config ) @@ -834,7 +849,8 @@ def grpo_train( print(f" • Loss: {metrics['loss']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print( - f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}" + f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}", + flush=True, ) if "total_flops" in train_results: total_tflops = ( @@ -842,16 +858,18 @@ def grpo_train( ) num_ranks = train_results["num_ranks"] print( - f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", + flush=True, ) if "theoretical_tflops" in train_results: theoretical_tflops = train_results["theoretical_tflops"] print( - f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + flush=True, ) metrics["train_fp_utilization"] = total_tflops / theoretical_tflops - print("\n⏱️ Timing:") + print("\n⏱️ Timing:", flush=True) # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) @@ -867,7 +885,7 @@ def grpo_train( } ) - print(f" • Total step time: {total_time:.2f}s") + print(f" • Total step time: {total_time:.2f}s", flush=True) # Display all other timing metrics for k, v in sorted( @@ -875,7 +893,7 @@ def grpo_train( ): if k != "total_step_time": percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") @@ -896,12 +914,12 @@ def validate( ) -> tuple[dict[str, Any], dict[str, Any]]: """Run validation on the validation dataset.""" if val_dataloader is None: - print(" ⚠️ No validation dataloader provided, skipping validation") + print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) return {}, {} timer = Timer() with timer.time("total_validation_time"): - print(f"▶ Starting validation at step {step}...") + print(f"▶ Starting validation at step {step}...", flush=True) total_rewards = [] total_lengths = [] @@ -974,7 +992,7 @@ def validate( ) except Exception as e: print(f"\n ⚠️ Error displaying message samples: {str(e)}") - print(" ⚠️ Continuing 
validation without displaying samples...") + print(" ⚠️ Continuing validation without displaying samples...", flush=True) # Get timing metrics timing_metrics = timer.get_timing_metrics(reduction_op="sum") @@ -984,12 +1002,12 @@ def validate( print("\n📊 Validation Results:") print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") - print(f" • Samples processed: {len(total_rewards)}") + print(f" • Samples processed: {len(total_rewards)}", flush=True) # Print timing information print("\n ⏱️ Validation Timing:") validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s") + print(f" • Total validation time: {validation_time:.2f}s", flush=True) # Make sure to reset the timer after validation timer.reset()