Skip to content
Closed
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ec7a4f6
adding an input template for formatting input
wasiahmad Oct 3, 2025
83dd6f2
adding an input template for formatting input
wasiahmad Oct 3, 2025
a08fcdd
minor bug fix
wasiahmad Oct 3, 2025
545f09e
replacing comma with semicolon to avoid hydra issues
wasiahmad Oct 3, 2025
c010c57
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 4, 2025
62164ed
using nemo-skills load config
wasiahmad Oct 4, 2025
081cd0e
Merge branch 'main' into apply_input_template
wasiahmad Oct 6, 2025
1260e81
Merge branch 'main' into apply_input_template
wasiahmad Oct 14, 2025
e4d35c5
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 15, 2025
ef650c6
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 17, 2025
2353efe
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 17, 2025
e0e9009
resolving conflicts
wasiahmad Oct 24, 2025
b6e3f46
Merge branch 'main' into apply_input_template
wasiahmad Oct 27, 2025
0863b47
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 28, 2025
7b610ff
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 28, 2025
3c9970e
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Oct 29, 2025
ef68fe7
Merge branch 'main' into apply_input_template
wasiahmad Dec 8, 2025
0d655ed
Merge remote-tracking branch 'origin/main' into apply_input_template
wasiahmad Jan 30, 2026
e992227
HF ASR Leaderboard Fix (#1140)
melllinia Jan 15, 2026
8d6766d
Merge branch 'main' into apply_input_template
wasiahmad Mar 3, 2026
66af4d5
Merge branch 'main' into apply_input_template
wasiahmad Mar 3, 2026
ee7cde3
Merge branch 'main' into apply_input_template
wasiahmad Apr 17, 2026
33359af
Pass tools from datum into SFT preprocessor formatting
wasiahmad Apr 17, 2026
7559512
SFT debug: print first five messages in message log
wasiahmad Apr 17, 2026
8be7bf1
fixing chat template apply issue with tool calls
wasiahmad Apr 21, 2026
e8a191e
Merge branch 'main' into apply_input_template
wasiahmad Apr 21, 2026
65f604e
Align asr-leaderboard dataset __init__ with main
wasiahmad Apr 21, 2026
0b1a79c
Fix PromptResponseDataset cache invalidation for SFT
wasiahmad Apr 21, 2026
6700726
adding moe_aux_loss_coeff=0.0 to bypass MoE aux loss validation asser…
wasiahmad Apr 27, 2026
f228e51
adding moe_aux_loss_coeff=0.0 to bypass MoE aux loss validation asser…
wasiahmad Apr 27, 2026
c702bc9
Merge branch 'main' into apply_input_template
wasiahmad May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions nemo_skills/training/nemo_rl/start_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pathlib import Path
from typing import Any, Dict, Optional

import yaml
from datasets import Dataset, load_dataset, load_from_disk
from nemo_rl.algorithms.sft import MasterConfig, setup, sft_train
from nemo_rl.algorithms.utils import get_tokenizer
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(
output_key: str = "output",
num_proc: int | None = None,
force_reprocess: bool = False,
input_template_path: str | None = None,
Comment thread
wasiahmad marked this conversation as resolved.
):
self.input_key = input_key
self.output_key = output_key
Expand All @@ -98,6 +100,14 @@ def __init__(
else:
self.num_proc = num_proc

self.input_template = None
if input_template_path:
with open(input_template_path, "rt", encoding="utf-8") as fin:
Comment thread
wasiahmad marked this conversation as resolved.
Outdated
data = yaml.safe_load(fin)
if "user" not in data:
raise KeyError(f"'user' key is missing in the YAML file: {input_template_path}")
self.input_template = data["user"]

# Train split
self.formatted_ds = {
"train": self.load_or_process_split(train_ds_path, "train"),
Expand Down Expand Up @@ -128,11 +138,22 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset:
print(f"[Map] Processing {split_name} dataset from: {path}")
dataset = load_dataset("json", data_files=str(path))["train"]

current_input_key = self.input_key
Comment thread
wasiahmad marked this conversation as resolved.
if self.input_template:
assert "messages" not in dataset.column_names
Comment thread
wasiahmad marked this conversation as resolved.
dataset = dataset.map(
self.apply_input_template,
batched=True,
num_proc=self.num_proc,
)
current_input_key = "formatted_input"

if "messages" not in dataset.column_names:
dataset = dataset.map(
self.add_messages_key,
batched=True,
num_proc=self.num_proc,
fn_kwargs={"input_key": current_input_key},
)

# Save dataset + new size signature
Expand All @@ -144,17 +165,26 @@ def load_or_process_split(self, path: str, split_name: str) -> Dataset:
print(f"[Cache] Saved {split_name} dataset to: {cache_dir}")
return dataset

def add_messages_key(self, examples: dict[str, list[Any]]) -> dict[str, list[list[dict[str, Any]]]]:
def add_messages_key(
self, examples: dict[str, list[Any]], input_key: str
) -> dict[str, list[list[dict[str, Any]]]]:
return {
"messages": [
[
{"role": "user", "content": input_},
{"role": "assistant", "content": output},
]
for input_, output in zip(examples[self.input_key], examples[self.output_key])
for input_, output in zip(examples[input_key], examples[self.output_key])
Comment thread
wasiahmad marked this conversation as resolved.
]
}

def apply_input_template(self, examples: dict[str, list[Any]]) -> dict[str, list[str]]:
    """Render ``self.input_template`` once per row and store it as ``formatted_input``.

    ``self.input_key`` is read as a comma-separated list of column names; each
    name is passed to ``str.format`` as a keyword argument with that row's value.

    Args:
        examples: Batched columns (column name -> list of values).

    Returns:
        The same ``examples`` mapping with a ``"formatted_input"`` column added.

    Raises:
        KeyError: If a column named in ``self.input_key`` is absent from the
            batch, or the template references a placeholder that is not listed.
    """
    keys = [k.strip() for k in self.input_key.split(",")]

    # Fail with the full list of missing columns rather than on the first lookup.
    missing = [k for k in keys if k not in examples]
    if missing:
        raise KeyError(f"Input template columns missing from dataset: {missing}")

    columns = [examples[k] for k in keys]
    # zip(*columns) walks the batch row by row, avoiding manual index arithmetic.
    examples["formatted_input"] = [
        self.input_template.format(**dict(zip(keys, row))) for row in zip(*columns)
    ]
    return examples
Comment thread
wasiahmad marked this conversation as resolved.
Comment thread
wasiahmad marked this conversation as resolved.


def parse_args():
"""Parse command line arguments."""
Expand Down Expand Up @@ -233,6 +263,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
data_config["input_key"],
data_config["output_key"],
force_reprocess=data_config.get("force_reprocess", False),
input_template_path=data_config.get("input_template_path", None),
)
print(f" ✓ Training dataset loaded with {len(data.formatted_ds['train'])} samples.")
if data.formatted_ds["validation"] is not None:
Expand Down