Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SFT Algorithm Configuration
sft:
Comment thread
ashors1 marked this conversation as resolved.
## total number of steps to train will equal
## min((max_num_epochs * len(train_dataloader)), max_num_steps)
max_num_epochs: 1
max_num_steps: 60

val_period: 10
val_batches: 8
val_global_batch_size: 32
Expand Down
281 changes: 157 additions & 124 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,25 @@


class SFTSaveState(TypedDict):
step: int
epoch: int # Track current epoch
step: int # Track step within current epoch
total_steps: int # Track total number of steps across all epochs
val_loss: float
consumed_samples: int


def _default_sft_save_state() -> SFTSaveState:
return {
"epoch": 0,
"step": 0,
"total_steps": 0,
"consumed_samples": 0,
}


class SFTConfig(TypedDict):
max_num_steps: int
max_num_epochs: int
val_period: int
val_batches: int
val_global_batch_size: int
Expand Down Expand Up @@ -140,6 +145,7 @@ def setup(
batch_size=policy_config["train_global_batch_size"],
shuffle=True,
collate_fn=rl_collate_fn,
drop_last=True,
)

if last_checkpoint_path is not None:
Expand Down Expand Up @@ -316,17 +322,22 @@ def sft_train(

if sft_save_state is None:
sft_save_state = _default_sft_save_state()
step = 0
current_epoch = 0
current_step = 0
total_steps = 0
else:
step = sft_save_state["step"]
current_epoch = sft_save_state["epoch"]
current_step = sft_save_state["step"]
total_steps = sft_save_state["total_steps"]

sft_config = master_config["sft"]
# Validation configuration
val_period = sft_config["val_period"]
val_at_start = sft_config["val_at_start"]
max_num_epochs = sft_config["max_num_epochs"]

# Run validation at the start if configured
if val_at_start and step == 0:
if val_at_start and total_steps == 0:
print("\n🔍 Running initial validation...")
val_metrics, validation_timings = validate(
policy,
Expand All @@ -341,130 +352,152 @@ def sft_train(
val_mbs=sft_config["val_micro_batch_size"],
)

logger.log_metrics(val_metrics, step, prefix="validation")
logger.log_metrics(validation_timings, step, prefix="timing/validation")
logger.log_metrics(val_metrics, total_steps, prefix="validation")
logger.log_metrics(validation_timings, total_steps, prefix="timing/validation")

policy.prepare_for_training()

for batch in train_dataloader:
print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}")

with timer.time("total_step_time"):
# Prepare batch and generate responses
print("▶ Preparing batch...")
with timer.time("data_processing"):
## add loss mask based on role to every message
add_loss_mask_to_message_log(
batch["message_log"],
roles_to_train_on=["assistant"],
)

cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)

train_data: BatchedDataDict = BatchedDataDict(
{
"input_ids": cat_and_padded["token_ids"],
"input_lengths": input_lengths,
"token_mask": cat_and_padded["token_loss_mask"],
"sample_mask": batch["loss_multiplier"],
}
)

## train_data.to("cpu")
print("▶ Taking a training step...")
train_results = policy.train(train_data, loss_fn)

# Run validation if it's a validation step
if val_period > 0 and (step + 1) % val_period == 0:
val_metrics, validation_timings = validate(
policy,
val_dataloader,
tokenizer,
loss_fn,
step=step + 1,
master_config=master_config,
sft_task_spec=sft_task_spec,
val_batches=sft_config["val_batches"],
val_batch_size=sft_config["val_global_batch_size"],
val_mbs=sft_config["val_micro_batch_size"],
)
logger.log_metrics(
validation_timings, step + 1, prefix="timing/validation"
)
logger.log_metrics(val_metrics, step + 1, prefix="validation")

## Checkpointing
sft_save_state["consumed_samples"] += master_config["policy"][
"train_global_batch_size"
]
if (
master_config["checkpointing"]["enabled"]
and (step + 1) % master_config["checkpointing"]["save_period"] == 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(len(train_dataloader), master_config["sft"]["max_num_steps"])
- (step + 1)
< master_config["checkpointing"]["save_period"]
)

sft_save_state["step"] = step + 1
sft_save_state["val_loss"] = val_metrics["val_loss"]
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {step + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
step + 1, sft_save_state, master_config
while (
current_epoch < max_num_epochs
and total_steps < master_config["sft"]["max_num_steps"]
):
print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")

for batch in train_dataloader:
Comment thread
terrykong marked this conversation as resolved.
print(
f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['sft']['max_num_steps'])} {'=' * 25}"
)

with timer.time("total_step_time"):
# Prepare batch and generate responses
print("▶ Preparing batch...")
with timer.time("data_processing"):
## add loss mask based on role to every message
add_loss_mask_to_message_log(
batch["message_log"],
roles_to_train_on=["assistant"],
)

cat_and_padded, input_lengths = batched_message_log_to_flat_message(
batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)

policy.save_checkpoint(
weights_path=os.path.join(checkpoint_path, "policy", "weights"),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
train_data: BatchedDataDict = BatchedDataDict(
{
"input_ids": cat_and_padded["token_ids"],
"input_lengths": input_lengths,
"token_mask": cat_and_padded["token_loss_mask"],
"sample_mask": batch["loss_multiplier"],
}
)
torch.save(
train_dataloader.state_dict(),
os.path.join(checkpoint_path, "train_dataloader.pt"),

print("▶ Taking a training step...")
train_results = policy.train(train_data, loss_fn)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
val_metrics, validation_timings = validate(
policy,
val_dataloader,
tokenizer,
loss_fn,
step=total_steps + 1,
master_config=master_config,
sft_task_spec=sft_task_spec,
val_batches=sft_config["val_batches"],
val_batch_size=sft_config["val_global_batch_size"],
val_mbs=sft_config["val_micro_batch_size"],
)
logger.log_metrics(
validation_timings, total_steps + 1, prefix="timing/validation"
)
logger.log_metrics(
val_metrics, total_steps + 1, prefix="validation"
)
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
metrics = {
"loss": train_results["loss"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
metrics = {k: np.mean(v).item() for k, v in metrics.items()}
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
print(f" • Total step time: {total_time:.2f}s")

# Display all other timing metrics (if any)
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")

timer.reset()
step += 1

if step >= master_config["sft"]["max_num_steps"]:
break

## Checkpointing
sft_save_state["consumed_samples"] += master_config["policy"][
"train_global_batch_size"
]
if (
master_config["checkpointing"]["enabled"]
and (total_steps + 1)
% master_config["checkpointing"]["save_period"]
== 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(
len(train_dataloader) * max_num_epochs,
master_config["sft"]["max_num_steps"],
)
- (total_steps + 1)
< master_config["checkpointing"]["save_period"]
)

sft_save_state["step"] = (current_step + 1) % len(train_dataloader)
sft_save_state["total_steps"] = total_steps + 1
sft_save_state["epoch"] = current_epoch
sft_save_state["val_loss"] = val_metrics["val_loss"]
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {total_steps + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
total_steps + 1, sft_save_state, master_config
)

policy.save_checkpoint(
weights_path=os.path.join(
checkpoint_path, "policy", "weights"
),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
train_dataloader.state_dict(),
os.path.join(checkpoint_path, "train_dataloader.pt"),
)
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
metrics = {
"loss": train_results["loss"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
metrics = {k: np.mean(v).item() for k, v in metrics.items()}
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
print(f" • Total step time: {total_time:.2f}s")

# Display all other timing metrics (if any)
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

logger.log_metrics(metrics, total_steps + 1, prefix="train")
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

timer.reset()
current_step += 1
total_steps += 1

if total_steps >= master_config["sft"]["max_num_steps"]:
Comment thread
ashors1 marked this conversation as resolved.
return

current_epoch += 1
current_step = 0 # Reset step counter for new epoch
10 changes: 5 additions & 5 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def train(
mbs: Optional[int] = None,
):
"""Train the policy on a batch of data with a given loss function."""
batch_size = gbs or self.cfg["train_global_batch_size"]
micro_batch_size = mbs or self.cfg["train_micro_batch_size"]
# Shard and replicate the batch
shards = self.dp_size
sharded_data = data.shard_by_batch_size(
shards, batch_size=self.cfg["train_global_batch_size"]
)
sharded_data = data.shard_by_batch_size(shards, batch_size=batch_size)

# Train each shard in parallel
futures = self.worker_group.run_all_workers_multiple_data(
Expand All @@ -166,8 +166,8 @@ def train(
common_kwargs={
"loss_fn": loss_fn,
"eval_mode": eval_mode,
"gbs": gbs,
"mbs": mbs,
"gbs": batch_size,
"mbs": micro_batch_size,
},
only_on="all_tied_workers",
)
Expand Down
Loading