Skip to content
Merged
20 changes: 15 additions & 5 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
# SFT Algorithm Configuration
sft:
num_steps: 100
#val_period: 10
#val_at_start: true
#checkpoint_dir: "results/sft"
num_steps: 20
val_period: 10
val_batches: 8
val_global_batch_size: 32
val_micro_batch_size: 2
val_at_start: true

checkpointing:
enabled: true
checkpoint_dir: "results/sft"
metric_name: "val_loss"
higher_is_better: false
keep_top_k: 3
save_period: 10

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
train_global_batch_size: 8
train_global_batch_size: 32
train_micro_batch_size: 2
learning_rate: 5.0e-6
max_total_sequence_length: 1024
Expand Down
110 changes: 100 additions & 10 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
import argparse
import os
import pprint
from typing import Dict, Any

from omegaconf import OmegaConf

from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.utils.config import load_config
from nemo_reinforcer.utils.logger import get_next_experiment_dir
from nemo_reinforcer.data import DataConfig, hf_datasets
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec
from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log
from transformers import AutoTokenizer
from nemo_reinforcer.models.policy import PolicyConfig


def parse_args():
Expand All @@ -39,6 +47,80 @@ def parse_args():
return args, overrides


# =======================================================
# Data Processing
# =======================================================
def sft_preprocessor(
    datum_dict: Dict[str, Any],
    task_data_spec: TaskDataSpec,
    tokenizer,
    max_seq_length: int,
    idx: int,
) -> DatumSpec:
    """Process a single raw example into a DatumSpec for SFT training.

    Args:
        datum_dict: Raw example; must contain a "messages" list of chat turns.
        task_data_spec: Task spec forwarded to the message formatter.
        tokenizer: Tokenizer used to turn each message into token ids.
        max_seq_length: Maximum allowed total token count across all messages.
        idx: Index of this example in the dataset.

    Returns:
        A dict with the tokenized message log, its total token length, a
        loss multiplier (0.0 for over-length samples, which are truncated to
        a stub and masked out of the loss), and the example index.
    """
    message_log = get_formatted_message_log(
        datum_dict["messages"], tokenizer, task_data_spec
    )

    length = sum(len(m["token_ids"]) for m in message_log)

    loss_multiplier = 1.0
    if length > max_seq_length:
        # Over-length sample: shrink each message to a tiny stub so the batch
        # still contains valid tensors, and zero the loss multiplier so the
        # sample does not contribute to training.
        for message in message_log:
            message["token_ids"] = message["token_ids"][
                : min(4, max_seq_length // len(message_log))
            ]
        loss_multiplier = 0.0
        # Fix: report the post-truncation token count, not the original one,
        # so downstream padding/batching sees a length consistent with the
        # actual message_log contents.
        length = sum(len(m["token_ids"]) for m in message_log)

    output = {
        "message_log": message_log,
        "length": length,
        "extra_env_info": None,
        "loss_multiplier": loss_multiplier,
        "idx": idx,
    }
    return output


def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
    """Load the configured HF dataset and wrap its train/validation splits.

    Returns a tuple of (train_dataset, val_dataset, tokenizer, task_spec),
    where both datasets are AllTaskProcessedDataset instances driven by
    ``sft_preprocessor``.
    """
    print("\n▶ Setting up data...")

    # Select the dataset implementation by its configured name.
    dataset_name = data_config["dataset_name"]
    if dataset_name == "open_assistant":
        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
    elif dataset_name == "squad":
        data = hf_datasets.SquadDataset()
    else:
        raise ValueError(f"Unknown dataset class: {dataset_name}")

    print(
        f" ✓ Training and validation datasets loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively."
    )

    sft_task_spec = data.task_spec
    tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])

    def _wrap(split: str) -> AllTaskProcessedDataset:
        # Apply the SFT preprocessor to every example of the given split.
        return AllTaskProcessedDataset(
            data.formatted_ds[split],
            tokenizer,
            sft_task_spec,
            sft_preprocessor,
            max_seq_length=data_config["max_input_seq_length"],
        )

    train_dataset = _wrap("train")
    val_dataset = _wrap("validation")

    return train_dataset, val_dataset, tokenizer, sft_task_spec


def main():
"""Main entry point."""
# Parse arguments
Expand All @@ -47,13 +129,12 @@ def main():
if not args.config:
args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml")

config = OmegaConf.load(args.config)
config = load_config(args.config)
print(f"Loaded configuration from: {args.config}")

if overrides:
override_conf = OmegaConf.from_cli()
print(f"Overrides: {override_conf}")
config = OmegaConf.merge(config, override_conf)
print(f"Overrides: {overrides}")
config = OmegaConf.merge(config, overrides)

config: MasterConfig = OmegaConf.to_container(config, resolve=True)
print("Applied CLI overrides")
Expand All @@ -66,24 +147,33 @@ def main():
print(f"📊 Using log directory: {config['logger']['log_dir']}")

init_ray()

# setup data
dataset, val_dataset, tokenizer, sft_task_spec = setup_data(
config["data"], config["policy"]
)
(
policy,
cluster,
dataloader,
tokenizer,
train_dataloader,
val_dataloader,
loss_fn,
master_config,
logger,
sft_task_spec,
) = setup(config)
checkpointer,
sft_save_state,
master_config,
) = setup(config, dataset, val_dataset)
sft_train(
policy,
dataloader,
train_dataloader,
val_dataloader,
tokenizer,
loss_fn,
master_config,
logger,
sft_task_spec,
checkpointer,
sft_save_state,
)


Expand Down
Loading