Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true
logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: false # Make sure you do a `wandb login [Your API key]` before running
Expand Down
2 changes: 2 additions & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ grpo:
val_at_start: false
max_val_samples: 480
val_batch_size: 32
seed: 42

loss_fn:
reference_policy_kl_penalty: 0.0
Expand Down Expand Up @@ -118,6 +119,7 @@ data:
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
dataset_name: "DeepScaler"
shuffle: true

env:
math:
Expand Down
2 changes: 2 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42

loss_fn:
reference_policy_kl_penalty: 0.01
Expand Down Expand Up @@ -127,6 +128,7 @@ data:
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
dataset_name: "OpenMathInstruct-2"
shuffle: true

env:
math:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ data:
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
dataset_name: "OpenMathInstruct-2"
shuffle: true

env:
math:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_sliding_puzzle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ policy:

data:
add_system_prompt: false
shuffle: false # disable dataloader shuffling; shuffling is handled within the dataset

env:
sliding_puzzle_game:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true

logger:
log_dir: "logs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true

logger:
log_dir: "logs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true

logger:
log_dir: "logs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true

logger:
log_dir: "logs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ policy:
data:
dataset_name: "HelpSteer3"
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true

logger:
log_dir: "logs"
wandb_enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -104,6 +105,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
Expand Down Expand Up @@ -105,6 +106,7 @@ data:
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ data:
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
logger:
log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long
wandb_enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ data:
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
logger:
log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp
wandb_enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ data:
dataset_name: squad
add_bos: true
add_eos: true
shuffle: true
logger:
log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp1
wandb_enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ data:
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
logger:
log_dir: logs/sft-llama3.2-1b-1n8g-fsdp2tp1
wandb_enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ data:
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true
logger:
log_dir: logs/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt
wandb_enabled: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ policy:
data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "HelpSteer3"
shuffle: true

logger:
log_dir: "logs" # Base directory for all logs
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ data:
add_bos: true
add_eos: true
add_generation_prompt: false
shuffle: true

logger:
log_dir: "logs" # Base directory for all logs
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ data:
add_eos: true
add_generation_prompt: true
output_key: 'generated_solution'
shuffle: true

logger:
log_dir: "logs" # Base directory for all logs
Expand Down
7 changes: 4 additions & 3 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def setup_data(
tokenizer: TokenizerType,
data_config: DataConfig,
env_configs: dict[str, Any],
seed: int,
) -> tuple[
AllTaskProcessedDataset,
Optional[AllTaskProcessedDataset],
Expand All @@ -140,12 +141,12 @@ def setup_data(
# Load OpenMathInstruct2Dataset using nemo rl datasets
if data_config["dataset_name"] == "OpenMathInstruct-2":
print("Loading nvidia/OpenMathInstruct2Dataset for training and validation")
data: Any = OpenMathInstruct2Dataset()
data: Any = OpenMathInstruct2Dataset(seed=seed)
elif data_config["dataset_name"] == "DeepScaler":
print(
"Loading agentica-org/DeepScaleR-Preview-Dataset for training and validation"
)
data: Any = DeepScalerDataset()
data: Any = DeepScalerDataset(seed=seed)
else:
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

Expand Down Expand Up @@ -236,7 +237,7 @@ def main() -> None:
val_dataset,
task_to_env,
val_task_to_env,
) = setup_data(tokenizer, config["data"], config["env"])
) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"])

(
policy,
Expand Down
4 changes: 3 additions & 1 deletion examples/run_grpo_sliding_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from transformers import AutoTokenizer

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.algorithms.utils import get_tokenizer, set_seed
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.games.sliding_puzzle import (
Expand Down Expand Up @@ -223,6 +223,8 @@ def main():

init_ray()

set_seed(config["grpo"]["seed"])

# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["tokenizer"])
config["policy"]["generation"] = configure_generation_config(
Expand Down
5 changes: 4 additions & 1 deletion examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
print("\n▶ Setting up data...")
data_cls = data_config["dataset_name"]
if data_cls == "open_assistant":
data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
data = hf_datasets.OasstDataset(
output_dir="/tmp/open_assistant", seed=data_config["seed"]
)
elif data_cls == "squad":
data = hf_datasets.SquadDataset()
elif data_cls == "prompt_response_dataset":
Expand All @@ -110,6 +112,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
split=data_config["split"],
output_key=data_config["output_key"],
prompt_file=data_config["prompt_file"],
seed=data_config["seed"],
)
elif data_cls == "openai_format":
data = hf_datasets.OpenAIFormatDataset(
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def setup(
train_dataloader = StatefulDataLoader(
train_dataset,
batch_size=policy_config["train_global_batch_size"],
shuffle=True,
shuffle=data_config["shuffle"],
collate_fn=partial(
dpo_collate_fn,
tokenizer=tokenizer,
Expand Down
Loading