diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index 4524338e4f..49a5570846 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -150,6 +150,7 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true
 logger:
   log_dir: "logs" # Base directory for all logs
   wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running
diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
index e742480739..0593c6fce8 100644
--- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml
+++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -10,6 +10,7 @@ grpo:
   val_at_start: false
   max_val_samples: 480
   val_batch_size: 32
+  seed: 42

 loss_fn:
   reference_policy_kl_penalty: 0.0
@@ -118,6 +119,7 @@ data:
   prompt_file: "examples/prompts/cot.txt"
   system_prompt_file: null
   dataset_name: "DeepScaler"
+  shuffle: true

 env:
   math:
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index b9be32bdda..c580446546 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -10,6 +10,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42

 loss_fn:
   reference_policy_kl_penalty: 0.01
@@ -127,6 +128,7 @@ data:
   prompt_file: "examples/prompts/cot.txt"
   system_prompt_file: null
   dataset_name: "OpenMathInstruct-2"
+  shuffle: true

 env:
   math:
diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml
index cf6ba44d75..feddff5a43 100644
--- a/examples/configs/grpo_math_1B_megatron.yaml
+++ b/examples/configs/grpo_math_1B_megatron.yaml
@@ -146,6 +146,7 @@ data:
   prompt_file: "examples/prompts/cot.txt"
   system_prompt_file: null
   dataset_name: "OpenMathInstruct-2"
+  shuffle: true

 env:
   math:
diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
index 97f54cc67a..925cf156c8 100644
--- a/examples/configs/grpo_sliding_puzzle.yaml
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -44,6 +44,7 @@ policy:
 data:
   add_system_prompt: false
+  shuffle: false # disable dataloader shuffle, shuffle is handled within the dataset

 env:
   sliding_puzzle_game:
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
index e7eaef706a..9a07cecfc4 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
@@ -73,6 +73,7 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true

 logger:
   log_dir: "logs"
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
index 4906550001..e7f98d5a53 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
@@ -73,6 +73,7 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true

 logger:
   log_dir: "logs"
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
index 789f4fcbdf..37f10e248d 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
@@ -106,6 +106,7 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true

 logger:
   log_dir: "logs"
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
index 7d480f58a3..42480f3000 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
@@ -106,6 +106,7 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true

 logger:
   log_dir: "logs"
diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
index 8863fad45f..1482dd1aed 100644
--- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
@@ -74,6 +74,8 @@ policy:
 data:
   dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true
+
 logger:
   log_dir: "logs"
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
index 102c274bd6..9ffa47a0dc 100644
--- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
+++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -104,6 +105,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
index ff89e45881..2b72eed335 100644
--- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
+++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
index d778674238..232ac0363b 100644
--- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
index ea4f5e66e0..2c3cd2d357 100644
--- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
index 9b8ecb47b9..2ee5a10ac4 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
index 4a21332a07..3fbbaf7800 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
index 54b60a3cfb..d0b9a8d1ef 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
index b0930e76c2..82246da9a9 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
@@ -9,6 +9,7 @@ grpo:
   val_at_start: false
   max_val_samples: 256
   val_batch_size: 256
+  seed: 42
 loss_fn:
   reference_policy_kl_penalty: 0.01
   ratio_clip_min: 0.2
@@ -105,6 +106,7 @@ data:
   prompt_file: examples/prompts/cot.txt
   system_prompt_file: null
   dataset_name: OpenMathInstruct-2
+  shuffle: true
 env:
   math:
     num_workers: 8
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml
index 8535855965..dabcdca5be 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml
@@ -57,6 +57,7 @@ data:
   add_bos: true
   add_eos: true
   add_generation_prompt: false
+  shuffle: true
 logger:
   log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
index 2eff0aabf6..edb4775b71 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
@@ -57,6 +57,7 @@ data:
   add_bos: true
   add_eos: true
   add_generation_prompt: false
+  shuffle: true
 logger:
   log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
index 07f5524000..67341716c4 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
@@ -101,6 +101,7 @@ data:
   dataset_name: squad
   add_bos: true
   add_eos: true
+  shuffle: true
 logger:
   log_dir: logs/sft-llama3.1-8b-instruct-1n8g-fsdp1
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
index c6311cf357..c45ec45fd7 100644
--- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
@@ -57,6 +57,7 @@ data:
   add_bos: true
   add_eos: true
   add_generation_prompt: false
+  shuffle: true
 logger:
   log_dir: logs/sft-llama3.2-1b-1n8g-fsdp2tp1
   wandb_enabled: true
diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
index 54d30dd80b..add0bebddf 100644
--- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
+++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
@@ -57,6 +57,7 @@ data:
   add_bos: true
   add_eos: true
   add_generation_prompt: false
+  shuffle: true
 logger:
   log_dir: logs/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt
   wandb_enabled: true
diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml
index 06abcce233..20d4cf6a18 100644
--- a/examples/configs/rm.yaml
+++ b/examples/configs/rm.yaml
@@ -125,6 +125,7 @@ policy:
 data:
   max_input_seq_length: ${policy.max_total_sequence_length}
   dataset_name: "HelpSteer3"
+  shuffle: true

 logger:
   log_dir: "logs" # Base directory for all logs
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index a592321cfe..fa295c9375 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -132,6 +132,7 @@ data:
   add_bos: true
   add_eos: true
   add_generation_prompt: false
+  shuffle: true

 logger:
   log_dir: "logs" # Base directory for all logs
diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml
index 1f1b88a8a9..aa8fa7d6d6 100644
--- a/examples/configs/sft_openmathinstruct2.yaml
+++ b/examples/configs/sft_openmathinstruct2.yaml
@@ -67,6 +67,7 @@ data:
   add_eos: true
   add_generation_prompt: true
   output_key: 'generated_solution'
+  shuffle: true

 logger:
   log_dir: "logs" # Base directory for all logs
diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py
index 006ad36a16..f31c2c212c 100644
--- a/examples/run_grpo_math.py
+++ b/examples/run_grpo_math.py
@@ -124,6 +124,7 @@ def setup_data(
     tokenizer: TokenizerType,
     data_config: DataConfig,
     env_configs: dict[str, Any],
+    seed: int,
 ) -> tuple[
     AllTaskProcessedDataset,
     Optional[AllTaskProcessedDataset],
@@ -140,12 +141,12 @@ def setup_data(
     # Load OpenMathInstruct2Dataset using nemo rl datasets
     if data_config["dataset_name"] == "OpenMathInstruct-2":
         print("Loading nvidia/OpenMathInstruct2Dataset for training and validation")
-        data: Any = OpenMathInstruct2Dataset()
+        data: Any = OpenMathInstruct2Dataset(seed=seed)
     elif data_config["dataset_name"] == "DeepScaler":
         print(
             "Loading agentica-org/DeepScaleR-Preview-Dataset for training and validation"
         )
-        data: Any = DeepScalerDataset()
+        data: Any = DeepScalerDataset(seed=seed)
     else:
         raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

@@ -236,7 +237,7 @@ def main() -> None:
         val_dataset,
         task_to_env,
         val_task_to_env,
-    ) = setup_data(tokenizer, config["data"], config["env"])
+    ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"])

     (
         policy,
diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py
index c5ccc65524..ca2359d0d2 100644
--- a/examples/run_grpo_sliding_puzzle.py
+++ b/examples/run_grpo_sliding_puzzle.py
@@ -24,7 +24,7 @@
 from transformers import AutoTokenizer

 from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
-from nemo_rl.algorithms.utils import get_tokenizer
+from nemo_rl.algorithms.utils import get_tokenizer, set_seed
 from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
 from nemo_rl.distributed.virtual_cluster import init_ray
 from nemo_rl.environments.games.sliding_puzzle import (
@@ -223,6 +223,8 @@ def main():
     init_ray()

+    set_seed(config["grpo"]["seed"])
+
     # setup tokenizer
     tokenizer = get_tokenizer(config["policy"]["tokenizer"])
     config["policy"]["generation"] = configure_generation_config(
diff --git a/examples/run_sft.py b/examples/run_sft.py
index df0d7ce3f7..fc2956b48a 100644
--- a/examples/run_sft.py
+++ b/examples/run_sft.py
@@ -95,7 +95,9 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
     print("\n▶ Setting up data...")
     data_cls = data_config["dataset_name"]
     if data_cls == "open_assistant":
-        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
+        data = hf_datasets.OasstDataset(
+            output_dir="/tmp/open_assistant", seed=data_config["seed"]
+        )
     elif data_cls == "squad":
         data = hf_datasets.SquadDataset()
     elif data_cls == "prompt_response_dataset":
@@ -110,6 +112,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
             split=data_config["split"],
             output_key=data_config["output_key"],
             prompt_file=data_config["prompt_file"],
+            seed=data_config["seed"],
         )
     elif data_cls == "openai_format":
         data = hf_datasets.OpenAIFormatDataset(
diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py
index 30ba78f6f2..6a03ffe7ee 100644
--- a/nemo_rl/algorithms/dpo.py
+++ b/nemo_rl/algorithms/dpo.py
@@ -152,7 +152,7 @@ def setup(
     train_dataloader = StatefulDataLoader(
         train_dataset,
         batch_size=policy_config["train_global_batch_size"],
-        shuffle=True,
+        shuffle=data_config["shuffle"],
         collate_fn=partial(
             dpo_collate_fn,
             tokenizer=tokenizer,
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index fceb2173c6..da453b0b11 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -29,7 +29,7 @@
     ClippedPGLossDataDict,
     ClippedPGLossFn,
 )
-from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt
+from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, set_seed
 from nemo_rl.data import DataConfig
 from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
 from nemo_rl.data.interfaces import (
@@ -84,6 +84,7 @@ class GRPOConfig(TypedDict):
     val_batch_size: int
     val_at_start: bool
     max_val_samples: int
+    seed: int


 class GRPOSaveState(TypedDict):
@@ -149,6 +150,7 @@ def setup(
     generation_config = master_config["policy"]["generation"]
     loss_config = master_config["loss_fn"]
     grpo_config = master_config["grpo"]
+    data_config = master_config["data"]
     logger_config = master_config["logger"]
     cluster_config = master_config["cluster"]

@@ -156,6 +158,9 @@ def setup(
         "A generation config in the PolicyConfig is required for GRPO"
     )

+    # Set seed for all random number generators
+    set_seed(grpo_config["seed"])
+
     # ==========================
     # Logger
     # ==========================
@@ -179,7 +184,7 @@ def setup(
     dataloader = StatefulDataLoader(
         dataset,
         batch_size=grpo_config["num_prompts_per_step"],
-        shuffle=False,
+        shuffle=data_config["shuffle"],
         collate_fn=rl_collate_fn,
         drop_last=True,
     )
diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py
index 9732c84259..1dafc3800d 100644
--- a/nemo_rl/algorithms/rm.py
+++ b/nemo_rl/algorithms/rm.py
@@ -145,7 +145,7 @@ def setup(
     train_dataloader = StatefulDataLoader(
         train_dataset,
         batch_size=policy_config["train_global_batch_size"],
-        shuffle=True,
+        shuffle=data_config["shuffle"],
         collate_fn=preference_collate_fn,
         drop_last=True,
     )
diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py
index 804909c2c4..7bb5590263 100644
--- a/nemo_rl/algorithms/sft.py
+++ b/nemo_rl/algorithms/sft.py
@@ -134,7 +134,7 @@ def setup(
     train_dataloader = StatefulDataLoader(
         train_dataset,
         batch_size=policy_config["train_global_batch_size"],
-        shuffle=True,
+        shuffle=data_config["shuffle"],
         collate_fn=rl_collate_fn,
         drop_last=True,
     )
diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py
index 9a9ce4b23a..df14a1546f 100644
--- a/nemo_rl/data/__init__.py
+++ b/nemo_rl/data/__init__.py
@@ -28,6 +28,7 @@ class DataConfig(TypedDict):
     add_generation_prompt: NotRequired[bool]
     add_system_prompt: NotRequired[bool]
     split: NotRequired[str]
+    shuffle: NotRequired[bool]


 class MathDataConfig(DataConfig):
diff --git a/nemo_rl/data/hf_datasets/oasst.py b/nemo_rl/data/hf_datasets/oasst.py
index a0c19b6909..3ba044e452 100644
--- a/nemo_rl/data/hf_datasets/oasst.py
+++ b/nemo_rl/data/hf_datasets/oasst.py
@@ -123,8 +123,8 @@ def download_and_process_oasst(


 class OasstDataset:
-    def __init__(self, output_dir: str = ".") -> None:
-        self.formatted_ds = download_and_process_oasst(output_dir)
+    def __init__(self, output_dir: str = ".", seed: int = 42) -> None:
+        self.formatted_ds = download_and_process_oasst(output_dir, seed)
         self.task_spec = TaskDataSpec(
             task_name="OASST",
         )
diff --git a/tests/unit/data/test_data_shuffle_reproducity.py b/tests/unit/data/test_data_shuffle_reproducity.py
new file mode 100644
index 0000000000..3821423d16
--- /dev/null
+++ b/tests/unit/data/test_data_shuffle_reproducity.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from collections import defaultdict
+
+import pytest
+import torch
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from examples.run_grpo_math import hf_data_processor
+from nemo_rl.algorithms.utils import get_tokenizer, set_seed
+from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
+from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
+from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec
+from nemo_rl.models.policy import TokenizerConfig
+
+# Test configuration
+TOKENIZER_CONFIG: TokenizerConfig = {
+    "name": "Qwen/Qwen2.5-Math-1.5B-Instruct",
+    "chat_template": "default",
+}
+
+MAX_BATCHES_TO_TEST = 10
+
+
+def create_dataloader(
+    seed: int = 42, max_seq_length: int = 128, batch_size: int = 4
+) -> StatefulDataLoader:
+    """Create a dataloader with consistent configuration for testing."""
+    # Initialize dataset
+    data = OpenMathInstruct2Dataset(seed=seed)
+
+    # Setup tokenizer
+    tokenizer = get_tokenizer(TOKENIZER_CONFIG)
+
+    # Configure task specification
+    math_task_spec = TaskDataSpec(
+        task_name="math",
+        prompt_file=f"{os.path.dirname(os.path.abspath(__file__))}/../../../examples/prompts/cot.txt",
+        system_prompt_file=None,
+    )
+
+    task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = (
+        defaultdict(lambda: (math_task_spec, hf_data_processor))
+    )
+    task_data_processors["math"] = (math_task_spec, hf_data_processor)
+
+    dataset = AllTaskProcessedDataset(
+        dataset=data.formatted_ds["train"].select(range(1000)),
+        tokenizer=tokenizer,
+        default_task_data_spec=math_task_spec,
+        task_data_processors=task_data_processors,
+        max_seq_length=max_seq_length,
+    )
+
+    return StatefulDataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=rl_collate_fn,
+        drop_last=True,
+    )
+
+
+@pytest.mark.parametrize("seed", [42, 24])
+def test_data_shuffle_reproducity_from_start(seed):
+    """Test that dataloader shuffling is reproducible with the same seed."""
+    # Step 1: Set seed and create initial dataloader
+    set_seed(seed)
+    original_dataloader = create_dataloader(seed=seed)
+
+    expected_batches = []
+    for batch in original_dataloader:
+        expected_batches.append(batch)
+        if len(expected_batches) >= MAX_BATCHES_TO_TEST:
+            break
+
+    # Step 2: to mimic a new experiment:
+    # set original seed and create new dataloader under the same seed environment
+    set_seed(seed)
+    new_dataloader = create_dataloader(seed=seed)
+
+    for i, (expected_batch, actual_batch) in enumerate(
+        zip(expected_batches, new_dataloader)
+    ):
+        assert str(expected_batch) == str(actual_batch), f"Batch {i} is different"
+
+
+@pytest.mark.parametrize("save_state_at_batch", [6, 10])
+def test_data_shuffle_reproducity_from_continue(save_state_at_batch, seed=42):
+    """Test that dataloader state can be saved and restored for continuation."""
+    # Step 1: Set seed and create initial dataloader
+    set_seed(seed)
+    original_dataloader = create_dataloader(seed=seed)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        expected_batches = []
+        for i, batch in enumerate(original_dataloader):
+            if (
+                i >= save_state_at_batch - 1
+            ):  # Stop after consuming exactly save_state_at_batch batches
+                if i == save_state_at_batch - 1:
+                    # Step 2: Save the state at this point
+                    state_file = os.path.join(temp_dir, "dataloader_state.pt")
+                    torch.save(original_dataloader.state_dict(), state_file)
+                else:
+                    # Step 3: Get the expected continuation from original dataloader
+                    expected_batches.append(batch)
+                    if len(expected_batches) >= MAX_BATCHES_TO_TEST:
+                        break
+
+        # Step 4: to mimic a continued experiment:
+        # set original seed and create new dataloader under the same seed environment,
+        # load the saved state, and continue from the saved point
+        set_seed(seed)
+        continued_dataloader = create_dataloader(seed=seed)
+
+        state_dict = torch.load(state_file)
+        continued_dataloader.load_state_dict(state_dict)
+
+        # Step 5: Get batches from the continued dataloader
+        actual_batches = []
+        for batch in continued_dataloader:
+            if len(actual_batches) >= MAX_BATCHES_TO_TEST:
+                break
+            actual_batches.append(batch)
+
+        assert len(actual_batches) == len(expected_batches)
+
+        # Step 6: Compare the batches - they should be identical
+        for i, (actual_batch, expected_batch) in enumerate(
+            zip(actual_batches, expected_batches)
+        ):
+            assert str(actual_batch) == str(expected_batch), (
+                f"Batch {i} from continued dataloader doesn't match expected batch\n"
+                f"actual_batch['idx']:\t{actual_batch['idx']}\n"
+                f"expected_batch['idx']:\t{expected_batch['idx']}"
+            )
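One piece of the seeding path is not shown in this patch: the body of `set_seed`, which `nemo_rl/algorithms/grpo.py` and `examples/run_grpo_sliding_puzzle.py` import from `nemo_rl.algorithms.utils`. The following is a minimal sketch of what a helper like this typically does, assuming it seeds Python's, NumPy's, and PyTorch's generators; this is an assumption about the body, not the actual implementation in `nemo_rl`:

```python
# Hypothetical sketch of a set_seed helper; the real body in
# nemo_rl.algorithms.utils is not part of this diff.
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs so that dataset and
    dataloader shuffling is reproducible across runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Also seed every visible CUDA device; a no-op on CPU-only hosts.
    torch.cuda.manual_seed_all(seed)
```

Whatever the exact body, the intent is clear from how the pieces fit together: `grpo.seed` seeds the global RNGs and the dataset-level shuffle (`OpenMathInstruct2Dataset(seed=seed)`, `DeepScalerDataset(seed=seed)`), `data.shuffle` is forwarded to each `StatefulDataLoader`, and the new test verifies both that two runs with the same seed produce identical batches and that a dataloader restored from `state_dict()` continues with the exact batch order of the original run.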