Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SFT Algorithm Configuration
sft:
## total number of steps to train will equal
## min((max_num_epochs * len(train_dataloader)), max_num_steps)
max_num_epochs: 1
max_num_steps: 60

val_period: 10
val_batches: 8
val_global_batch_size: 32
Expand Down
281 changes: 157 additions & 124 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,25 @@


class SFTSaveState(TypedDict):
step: int
epoch: int # Track current epoch
step: int # Track step within current epoch
total_steps: int # Track total number of steps across all epochs
val_loss: float
consumed_samples: int


def _default_sft_save_state() -> SFTSaveState:
return {
"epoch": 0,
"step": 0,
"total_steps": 0,
"consumed_samples": 0,
}


class SFTConfig(TypedDict):
max_num_steps: int
max_num_epochs: int
val_period: int
val_batches: int
val_global_batch_size: int
Expand Down Expand Up @@ -141,6 +146,7 @@ def setup(
batch_size=policy_config["train_global_batch_size"],
shuffle=True,
collate_fn=rl_collate_fn,
drop_last=True,
)

if last_checkpoint_path is not None:
Expand Down Expand Up @@ -333,17 +339,22 @@ def sft_train(

if sft_save_state is None:
sft_save_state = _default_sft_save_state()
step = 0
current_epoch = 0
current_step = 0
total_steps = 0
else:
step = sft_save_state["step"]
current_epoch = sft_save_state["epoch"]
current_step = sft_save_state["step"]
total_steps = sft_save_state["total_steps"]

sft_config = master_config["sft"]
# Validation configuration
val_period = sft_config["val_period"]
val_at_start = sft_config["val_at_start"]
max_num_epochs = sft_config["max_num_epochs"]

# Run validation at the start if configured
if val_at_start and step == 0:
if val_at_start and total_steps == 0:
print("\n🔍 Running initial validation...")
val_metrics, validation_timings = validate(
policy,
Expand All @@ -358,134 +369,156 @@ def sft_train(
val_mbs=sft_config["val_micro_batch_size"],
)

logger.log_metrics(val_metrics, step, prefix="validation")
logger.log_metrics(validation_timings, step, prefix="timing/validation")
logger.log_metrics(val_metrics, total_steps, prefix="validation")
logger.log_metrics(validation_timings, total_steps, prefix="timing/validation")

policy.prepare_for_training()

for batch in train_dataloader:
print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}")

with timer.time("total_step_time"):
# Prepare batch and generate responses
print("▶ Preparing batch...")
with timer.time("data_processing"):
## add loss mask based on role to every message
add_loss_mask_to_message_log(
batch["message_log"],
roles_to_train_on=["assistant"],
)
while (
current_epoch < max_num_epochs
and total_steps < master_config["sft"]["max_num_steps"]
):
print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")

cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)
for batch in train_dataloader:
print(
f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['sft']['max_num_steps'])} {'=' * 25}"
)

train_data: BatchedDataDict = BatchedDataDict(
{
"input_ids": cat_and_padded["token_ids"],
"input_lengths": input_lengths,
"token_mask": cat_and_padded["token_loss_mask"],
"sample_mask": batch["loss_multiplier"],
}
)
with timer.time("total_step_time"):
# Prepare batch and generate responses
print("▶ Preparing batch...")
with timer.time("data_processing"):
## add loss mask based on role to every message
add_loss_mask_to_message_log(
batch["message_log"],
roles_to_train_on=["assistant"],
)

## train_data.to("cpu")
print("▶ Taking a training step...")
train_results = policy.train(train_data, loss_fn)

# Run validation if it's a validation step
if val_period > 0 and (step + 1) % val_period == 0:
val_metrics, validation_timings = validate(
policy,
val_dataloader,
tokenizer,
loss_fn,
step=step + 1,
master_config=master_config,
sft_task_spec=sft_task_spec,
val_batches=sft_config["val_batches"],
val_batch_size=sft_config["val_global_batch_size"],
val_mbs=sft_config["val_micro_batch_size"],
)
logger.log_metrics(
validation_timings, step + 1, prefix="timing/validation"
)
logger.log_metrics(val_metrics, step + 1, prefix="validation")

## Checkpointing
sft_save_state["consumed_samples"] += master_config["policy"][
"train_global_batch_size"
]
if (
master_config["checkpointing"]["enabled"]
and (step + 1) % master_config["checkpointing"]["save_period"] == 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(len(train_dataloader), master_config["sft"]["max_num_steps"])
- (step + 1)
< master_config["checkpointing"]["save_period"]
)
cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)

sft_save_state["step"] = step + 1
sft_save_state["val_loss"] = val_metrics["val_loss"]
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {step + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
step + 1, sft_save_state, master_config
train_data: BatchedDataDict = BatchedDataDict(
{
"input_ids": cat_and_padded["token_ids"],
"input_lengths": input_lengths,
"token_mask": cat_and_padded["token_loss_mask"],
"sample_mask": batch["loss_multiplier"],
}
)

policy.save_checkpoint(
weights_path=os.path.join(checkpoint_path, "policy", "weights"),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
print("▶ Taking a training step...")
train_results = policy.train(train_data, loss_fn)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
val_metrics, validation_timings = validate(
policy,
val_dataloader,
tokenizer,
loss_fn,
step=total_steps + 1,
master_config=master_config,
sft_task_spec=sft_task_spec,
val_batches=sft_config["val_batches"],
val_batch_size=sft_config["val_global_batch_size"],
val_mbs=sft_config["val_micro_batch_size"],
)
torch.save(
train_dataloader.state_dict(),
os.path.join(checkpoint_path, "train_dataloader.pt"),
logger.log_metrics(
validation_timings, total_steps + 1, prefix="timing/validation"
)
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
metrics = {
"loss": train_results["loss"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k == "num_valid_samples":
metrics[k] = np.sum(v).item()
else:
metrics[k] = np.mean(v).item()
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
print(f" • Total step time: {total_time:.2f}s")

# Display all other timing metrics (if any)
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")

timer.reset()
step += 1

if step >= master_config["sft"]["max_num_steps"]:
break
logger.log_metrics(
val_metrics, total_steps + 1, prefix="validation"
)

## Checkpointing
sft_save_state["consumed_samples"] += master_config["policy"][
"train_global_batch_size"
]
if (
master_config["checkpointing"]["enabled"]
and (total_steps + 1)
% master_config["checkpointing"]["save_period"]
== 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(
len(train_dataloader) * max_num_epochs,
master_config["sft"]["max_num_steps"],
)
- (total_steps + 1)
< master_config["checkpointing"]["save_period"]
)

sft_save_state["step"] = (current_step + 1) % len(train_dataloader)
sft_save_state["total_steps"] = total_steps + 1
sft_save_state["epoch"] = current_epoch
sft_save_state["val_loss"] = val_metrics["val_loss"]
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {total_steps + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
total_steps + 1, sft_save_state, master_config
)

policy.save_checkpoint(
weights_path=os.path.join(
checkpoint_path, "policy", "weights"
),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
train_dataloader.state_dict(),
os.path.join(checkpoint_path, "train_dataloader.pt"),
)
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
metrics = {
"loss": train_results["loss"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k == "num_valid_samples":
metrics[k] = np.sum(v).item()
else:
metrics[k] = np.mean(v).item()
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
print(f" • Total step time: {total_time:.2f}s")

# Display all other timing metrics (if any)
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

logger.log_metrics(metrics, total_steps + 1, prefix="train")
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

timer.reset()
current_step += 1
total_steps += 1

if total_steps >= master_config["sft"]["max_num_steps"]:
return

current_epoch += 1
current_step = 0 # Reset step counter for new epoch
Loading