diff --git a/examples/run_sft.py b/examples/run_sft.py
index bcda89e09c..8f65262c73 100644
--- a/examples/run_sft.py
+++ b/examples/run_sft.py
@@ -107,7 +107,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
     val_dataset = data.formatted_ds["validation"]
     sft_task_spec = data.task_spec
     print(
-        f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset)} samples, respectively."
+        f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset) if val_dataset else 0} samples, respectively."
     )

     # add preprocessor if needed
@@ -133,19 +133,20 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
         max_seq_length=data_config["max_input_seq_length"],
     )

-    val_dataset = AllTaskProcessedDataset(
-        val_dataset,
-        tokenizer,
-        sft_task_spec,
-        partial(
-            sft_preprocessor,
-            add_bos=data_config.get("add_bos", True),
-            add_eos=data_config.get("add_eos", True),
-            add_generation_prompt=data_config["add_generation_prompt"],
-            datum_preprocessor=datum_preprocessor,
-        ),
-        max_seq_length=data_config["max_input_seq_length"],
-    )
+    if val_dataset is not None:
+        val_dataset = AllTaskProcessedDataset(
+            val_dataset,
+            tokenizer,
+            sft_task_spec,
+            partial(
+                sft_preprocessor,
+                add_bos=data_config.get("add_bos", True),
+                add_eos=data_config.get("add_eos", True),
+                add_generation_prompt=data_config["add_generation_prompt"],
+                datum_preprocessor=datum_preprocessor,
+            ),
+            max_seq_length=data_config["max_input_seq_length"],
+        )

     return train_dataset, val_dataset, sft_task_spec

diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py
index ac44521ef7..09cbdf93c2 100644
--- a/nemo_rl/algorithms/sft.py
+++ b/nemo_rl/algorithms/sft.py
@@ -90,12 +90,12 @@ def setup(
     master_config: MasterConfig,
     tokenizer: AutoTokenizer,
     train_dataset: AllTaskProcessedDataset,
-    val_dataset: AllTaskProcessedDataset,
+    val_dataset: Optional[AllTaskProcessedDataset],
 ) -> tuple[
     Policy,
     RayVirtualCluster,
     StatefulDataLoader,
-    StatefulDataLoader,
+    Optional[StatefulDataLoader],
     NLLLoss,
     Logger,
     CheckpointManager,
@@ -149,14 +149,17 @@ def setup(
         )
         train_dataloader.load_state_dict(dataloader_state_dict)

-    val_dataloader = StatefulDataLoader(
-        val_dataset,
-        batch_size=sft_config["val_global_batch_size"],
-        shuffle=False,
-        collate_fn=rl_collate_fn,
-        drop_last=False,
-        num_workers=data_config["num_workers"],
-    )
+    if val_dataset is not None:
+        val_dataloader = StatefulDataLoader(
+            val_dataset,
+            batch_size=sft_config["val_global_batch_size"],
+            shuffle=False,
+            collate_fn=rl_collate_fn,
+            drop_last=False,
+            num_workers=data_config["num_workers"],
+        )
+    else:
+        val_dataloader = None

     # ==========================
     # Cluster
@@ -230,7 +233,7 @@ def setup(
 # =======================================================
 def validate(
     policy: PolicyInterface,
-    val_dataloader: StatefulDataLoader,
+    val_dataloader: Optional[StatefulDataLoader],
     tokenizer,
     loss_fn,
     step: int,
@@ -242,11 +245,11 @@ def validate(
 ):
     """Run validation on the validation dataset."""
     if val_dataloader is None:
-        assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
-            "val_dataloader is None, so dpo.val_period must be 0"
+        assert master_config["sft"]["val_period"] <= 0, (
+            "val_dataloader is None, so sft.val_period must be <= 0"
         )
         print(" ⚠️ No validation dataloader provided, skipping validation")
-        return
+        return {}, {}

     timer = Timer()

@@ -496,7 +499,7 @@ def sft_train(
                         metrics[k] = np.mean(v).item()
                     else:
                         metrics[k] = np.sum(v).item()
-                total_valid_tokens += metrics["global_valid_toks"]
+                total_valid_tokens += metrics.get("global_valid_toks", 0)

                 ## Checkpointing
                 sft_save_state["consumed_samples"] += master_config["policy"][
@@ -610,9 +613,12 @@ def sft_train(
                 master_config["cluster"]["num_nodes"]
                 * master_config["cluster"]["gpus_per_node"]
             )
-            timing_metrics["valid_tokens_per_sec_per_gpu"] = (
-                metrics["global_valid_toks"] / total_time / total_num_gpus
-            )
+            if total_time > 0:
+                timing_metrics["valid_tokens_per_sec_per_gpu"] = (
+                    metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
+                )
+            else:
+                timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0
             logger.log_metrics(metrics, total_steps + 1, prefix="train")
             logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py
index e43630651e..83d0cf20c4 100644
--- a/tests/unit/algorithms/test_sft.py
+++ b/tests/unit/algorithms/test_sft.py
@@ -205,3 +205,51 @@ def test_exit_on_timeout(mock_components, capsys):
         assert "Epoch" not in line or "Epoch 1/10" in line, (
             f"Training continued to next epoch after timeout: {line}"
         )
+
+
+def test_training_with_disabled_validation(mock_components):
+    """Test that training works when validation is disabled (val_dataloader=None, val_period<=0)"""
+    mock_components["master_config"]["sft"]["val_period"] = 0
+    mock_components["master_config"]["sft"]["max_num_steps"] = 5
+    mock_components["master_config"]["sft"]["max_num_epochs"] = 1
+
+    sft_save_state = _default_sft_save_state()
+
+    sft_train(
+        mock_components["policy"],
+        mock_components["train_dataloader"],
+        None,  # val_dataloader is None
+        mock_components["tokenizer"],
+        mock_components["loss_fn"],
+        mock_components["master_config"],
+        mock_components["logger"],
+        mock_components["sft_task_spec"],
+        mock_components["checkpointer"],
+        sft_save_state,
+    )
+
+    assert mock_components["policy"].train.call_count == 5
+
+
+def test_training_with_negative_val_period(mock_components):
+    """Test that training works when val_period is negative (validation disabled)"""
+    mock_components["master_config"]["sft"]["val_period"] = -1
+    mock_components["master_config"]["sft"]["max_num_steps"] = 3
+    mock_components["master_config"]["sft"]["max_num_epochs"] = 1
+
+    sft_save_state = _default_sft_save_state()
+
+    sft_train(
+        mock_components["policy"],
+        mock_components["train_dataloader"],
+        None,  # val_dataloader is None
+        mock_components["tokenizer"],
+        mock_components["loss_fn"],
+        mock_components["master_config"],
+        mock_components["logger"],
+        mock_components["sft_task_spec"],
+        mock_components["checkpointer"],
+        sft_save_state,
+    )
+
+    assert mock_components["policy"].train.call_count == 3