Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
val_dataset = data.formatted_ds["validation"]
sft_task_spec = data.task_spec
print(
f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset)} samples, respectively."
f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset) if val_dataset else 0} samples, respectively."
)

# add preprocessor if needed
Expand All @@ -133,19 +133,20 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int):
max_seq_length=data_config["max_input_seq_length"],
)

val_dataset = AllTaskProcessedDataset(
val_dataset,
tokenizer,
sft_task_spec,
partial(
sft_preprocessor,
add_bos=data_config.get("add_bos", True),
add_eos=data_config.get("add_eos", True),
add_generation_prompt=data_config["add_generation_prompt"],
datum_preprocessor=datum_preprocessor,
),
max_seq_length=data_config["max_input_seq_length"],
)
if val_dataset is not None:
val_dataset = AllTaskProcessedDataset(
val_dataset,
tokenizer,
sft_task_spec,
partial(
sft_preprocessor,
add_bos=data_config.get("add_bos", True),
add_eos=data_config.get("add_eos", True),
add_generation_prompt=data_config["add_generation_prompt"],
datum_preprocessor=datum_preprocessor,
),
max_seq_length=data_config["max_input_seq_length"],
)

return train_dataset, val_dataset, sft_task_spec

Expand Down
42 changes: 24 additions & 18 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
train_dataset: AllTaskProcessedDataset,
val_dataset: AllTaskProcessedDataset,
val_dataset: Optional[AllTaskProcessedDataset],
) -> tuple[
Policy,
RayVirtualCluster,
StatefulDataLoader,
StatefulDataLoader,
Optional[StatefulDataLoader],
NLLLoss,
Logger,
CheckpointManager,
Expand Down Expand Up @@ -149,14 +149,17 @@ def setup(
)
train_dataloader.load_state_dict(dataloader_state_dict)

val_dataloader = StatefulDataLoader(
val_dataset,
batch_size=sft_config["val_global_batch_size"],
shuffle=False,
collate_fn=rl_collate_fn,
drop_last=False,
num_workers=data_config["num_workers"],
)
if val_dataset is not None:
val_dataloader = StatefulDataLoader(
val_dataset,
batch_size=sft_config["val_global_batch_size"],
shuffle=False,
collate_fn=rl_collate_fn,
drop_last=False,
num_workers=data_config["num_workers"],
)
else:
val_dataloader = None

# ==========================
# Cluster
Expand Down Expand Up @@ -230,7 +233,7 @@ def setup(
# =======================================================
def validate(
policy: PolicyInterface,
val_dataloader: StatefulDataLoader,
val_dataloader: Optional[StatefulDataLoader],
tokenizer,
loss_fn,
step: int,
Expand All @@ -242,11 +245,11 @@ def validate(
):
"""Run validation on the validation dataset."""
if val_dataloader is None:
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
"val_dataloader is None, so dpo.val_period must be 0"
assert master_config["sft"]["val_period"] <= 0, (
"val_dataloader is None, so sft.val_period must be <= 0"
)
print(" ⚠️ No validation dataloader provided, skipping validation")
return
return {}, {}

timer = Timer()

Expand Down Expand Up @@ -496,7 +499,7 @@ def sft_train(
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
total_valid_tokens += metrics["global_valid_toks"]
total_valid_tokens += metrics.get("global_valid_toks", 0)

## Checkpointing
sft_save_state["consumed_samples"] += master_config["policy"][
Expand Down Expand Up @@ -610,9 +613,12 @@ def sft_train(
master_config["cluster"]["num_nodes"]
* master_config["cluster"]["gpus_per_node"]
)
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics["global_valid_toks"] / total_time / total_num_gpus
)
if total_time > 0:
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
)
else:
timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0
logger.log_metrics(metrics, total_steps + 1, prefix="train")
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

Expand Down
48 changes: 48 additions & 0 deletions tests/unit/algorithms/test_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,51 @@ def test_exit_on_timeout(mock_components, capsys):
assert "Epoch" not in line or "Epoch 1/10" in line, (
f"Training continued to next epoch after timeout: {line}"
)


def test_training_with_disabled_validation(mock_components):
    """Training must run to completion when validation is switched off.

    The validation dataloader is passed as ``None`` and ``sft.val_period``
    is set to 0, so the loop should never attempt a validation pass and
    should execute exactly the configured number of training steps.
    """
    sft_cfg = mock_components["master_config"]["sft"]
    sft_cfg["val_period"] = 0
    sft_cfg["max_num_steps"] = 5
    sft_cfg["max_num_epochs"] = 1

    save_state = _default_sft_save_state()

    sft_train(
        mock_components["policy"],
        mock_components["train_dataloader"],
        None,  # val_dataloader is None
        mock_components["tokenizer"],
        mock_components["loss_fn"],
        mock_components["master_config"],
        mock_components["logger"],
        mock_components["sft_task_spec"],
        mock_components["checkpointer"],
        save_state,
    )

    # One policy.train() call per training step; all 5 steps must have run.
    assert mock_components["policy"].train.call_count == 5


def test_training_with_negative_val_period(mock_components):
    """A negative ``sft.val_period`` must also disable validation.

    With ``val_dataloader=None`` and ``val_period=-1`` the training loop
    should skip validation entirely and complete the configured steps.
    """
    sft_cfg = mock_components["master_config"]["sft"]
    sft_cfg["val_period"] = -1
    sft_cfg["max_num_steps"] = 3
    sft_cfg["max_num_epochs"] = 1

    save_state = _default_sft_save_state()

    sft_train(
        mock_components["policy"],
        mock_components["train_dataloader"],
        None,  # val_dataloader is None
        mock_components["tokenizer"],
        mock_components["loss_fn"],
        mock_components["master_config"],
        mock_components["logger"],
        mock_components["sft_task_spec"],
        mock_components["checkpointer"],
        save_state,
    )

    # One policy.train() call per training step; all 3 steps must have run.
    assert mock_components["policy"].train.call_count == 3
Loading