diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc02d93e42..382ded03ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -71,6 +71,7 @@ repos: base="examples/configs/dpo.yaml"; for f in examples/configs/recipes/llm/dpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done base="examples/configs/grpo_math_1B.yaml"; for f in examples/configs/recipes/llm/grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done base="examples/configs/sft.yaml"; for f in examples/configs/recipes/llm/sft-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + base="examples/configs/distillation_math.yaml"; for f in examples/configs/recipes/llm/distillation-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - id: configs-minimize-check-vlm name: minimize-check vlm recipes language: system diff --git a/docs/design-docs/generation.md b/docs/design-docs/generation.md index 13d7b3dc65..6890f0b2ac 100644 --- a/docs/design-docs/generation.md +++ b/docs/design-docs/generation.md @@ -16,7 +16,7 @@ The core of the generation system is defined in `interfaces.py`, which establish max_new_tokens: int # Maximum number of tokens to generate temperature: float # Sampling temperature top_p: float # Top-p sampling parameter - top_k: int # Top-k sampling parameter + top_k: int | None # Top-k sampling parameter model_name: str # Name or path of the model ``` diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index 92ae09d8ee..0a5823f5fb 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -103,7 +103,7 @@ policy: &POLICY_BASE #gives ~20% training perf speedup with sequence packing apply_rope_fusion: True bias_activation_fusion: True - defer_fp32_logits: null + defer_fp32_logits: False optimizer: optimizer: "adam" diff --git 
a/examples/configs/distillation_math_megatron.yaml b/examples/configs/distillation_math_megatron.yaml index 3df59eba84..89ac69fa9a 100644 --- a/examples/configs/distillation_math_megatron.yaml +++ b/examples/configs/distillation_math_megatron.yaml @@ -57,7 +57,7 @@ policy: &POLICY_BASE #gives ~20% training perf speedup with sequence packing apply_rope_fusion: True bias_activation_fusion: True - defer_fp32_logits: null + defer_fp32_logits: False optimizer: optimizer: "adam" diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 72823364cf..57fc613106 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -113,6 +113,7 @@ policy: moe_permute_fusion: false #gives ~20% training perf speedup with sequence packing apply_rope_fusion: True + defer_fp32_logits: False optimizer: optimizer: "adam" @@ -155,6 +156,7 @@ policy: overlap_param_gather: true average_in_collective: true data_parallel_sharding_strategy: "optim_grads_params" + use_custom_fsdp: false data: max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 976eade4ce..f4a636080c 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -63,7 +63,7 @@ policy: tokenizer: name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true - hf_config_overrides: null + hf_config_overrides: {} train_global_batch_size: 512 train_micro_batch_size: 4 generation_batch_size: 32 # Only used when generating using HF backend @@ -103,7 +103,7 @@ policy: moe_permute_fusion: false #gives ~20% training perf speedup with sequence packing apply_rope_fusion: True - defer_fp32_logits: null + defer_fp32_logits: False optimizer: optimizer: "adam" diff --git 
a/examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml b/examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml index d2b4ec620f..6496b11c2c 100644 --- a/examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml +++ b/examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml @@ -8,7 +8,6 @@ loss_fn: kl_type: reverse checkpointing: checkpoint_dir: checkpoints/distillation-qwen3-32b-to-4b-base-long - save_period: 10 policy: model_name: Qwen/Qwen3-4B-Base max_total_sequence_length: 20480 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.yaml index a2e1b15d2e..15e14fa380 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.yaml @@ -41,8 +41,6 @@ policy: NVTE_FP8_BLOCK_SCALING_FP32_SCALES: '1' generation: max_new_tokens: 4096 - stop_token_ids: - - 128009 vllm_cfg: precision: fp8 gpu_memory_utilization: 0.5 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml index 809bff5916..2d88c9a100 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml @@ -32,8 +32,6 @@ policy: lr_warmup_init: 5.0e-08 generation: max_new_tokens: 4096 - stop_token_ids: - - 128009 vllm_cfg: precision: fp8 max_model_len: 4096 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml index 162db25460..d9161477af 
100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml @@ -34,8 +34,6 @@ policy: - 13 generation: max_new_tokens: 4096 - stop_token_ids: - - 128009 vllm_cfg: async_engine: true max_model_len: 4096 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index df9181f660..cc8983a465 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -34,8 +34,6 @@ policy: - 13 generation: max_new_tokens: 4096 - stop_token_ids: - - 128009 vllm_cfg: max_model_len: 4096 data: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index fce039a321..6eb8ed4872 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -14,8 +14,6 @@ policy: make_sequence_length_divisible_by: 1 generation: max_new_tokens: 512 - stop_token_ids: - - 128009 vllm_cfg: max_model_len: 512 data: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml index 48f00c626e..333a06d980 100755 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml @@ -19,8 +19,6 @@ policy: make_sequence_length_divisible_by: 1 generation: max_new_tokens: 512 - stop_token_ids: - - 128009 vllm_cfg: max_model_len: 512 data: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml 
b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index b5aaf22ceb..0c238df73c 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -38,8 +38,6 @@ policy: - 13 generation: max_new_tokens: 16384 - stop_token_ids: - - 151643 vllm_cfg: tensor_parallel_size: 4 max_model_len: 16384 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index 44c2f7f8eb..1fa870751b 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -38,8 +38,6 @@ policy: - 13 generation: max_new_tokens: 16384 - stop_token_ids: - - 151643 vllm_cfg: tensor_parallel_size: 4 max_model_len: 16384 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 98e7eadedd..2e868bccc8 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -37,8 +37,6 @@ policy: - 13 generation: max_new_tokens: 4096 - stop_token_ids: - - 151645 vllm_cfg: tensor_parallel_size: 4 max_model_len: 4096 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml index a42ea746a7..fd0a48a663 100755 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml @@ -37,8 +37,6 @@ policy: - 13 generation: max_new_tokens: 4096 - stop_token_ids: - - 151645 vllm_cfg: tensor_parallel_size: 4 max_model_len: 
4096 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index c417c00dbd..bb62bf99ef 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -14,8 +14,6 @@ policy: make_sequence_length_divisible_by: 1 generation: max_new_tokens: 512 - stop_token_ids: - - 151645 vllm_cfg: max_model_len: 512 data: diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index b742846d2d..4fb749a774 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -35,6 +35,7 @@ policy: dtensor_cfg: enabled: true + env_vars: {} cpu_offload: False sequence_parallel: false activation_checkpointing: false @@ -73,6 +74,7 @@ policy: ## ignored since enabled=false, but needed for testing purposes megatron_cfg: enabled: false + env_vars: {} empty_unused_memory_level: 1 activation_checkpointing: false tensor_model_parallel_size: 1 @@ -90,7 +92,8 @@ policy: moe_router_bias_update_rate: 1e-3 moe_permute_fusion: false #gives ~20% training perf speedup with sequence packing - apply_rope_fusion: True + apply_rope_fusion: True + defer_fp32_logits: False optimizer: optimizer: "adam" diff --git a/examples/configs/sft_openmathinstruct2_megatron.yaml b/examples/configs/sft_openmathinstruct2_megatron.yaml index 696d025976..9532f09c1f 100644 --- a/examples/configs/sft_openmathinstruct2_megatron.yaml +++ b/examples/configs/sft_openmathinstruct2_megatron.yaml @@ -40,6 +40,7 @@ policy: grad_reduce_in_fp32: true overlap_grad_reduce: true overlap_param_gather: true + use_custom_fsdp: false empty_unused_memory_level: 1 enabled: true expert_tensor_parallel_size: 1 diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 1b2bce1401..181736765d 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ 
b/examples/configs/vlm_grpo_3B.yaml @@ -93,7 +93,7 @@ policy: moe_permute_fusion: false #gives ~20% training perf speedup with sequence packing apply_rope_fusion: True - defer_fp32_logits: null + defer_fp32_logits: False optimizer: optimizer: "adam" @@ -116,6 +116,10 @@ policy: use_distributed_optimizer: true use_precision_aware_optimizer: true + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + clip_grad: ${policy.max_grad_norm} scheduler: diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index fafb9b47ef..7d21cf70ea 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -77,7 +77,6 @@ policy: logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} algorithm: modified_first_fit_decreasing sequence_length_round: 64 - optimizer: null scheduler: - name: torch.optim.lr_scheduler.LinearLR kwargs: @@ -133,6 +132,7 @@ policy: moe_router_bias_update_rate: 0.0 moe_permute_fusion: false apply_rope_fusion: true + defer_fp32_logits: False optimizer: optimizer: adam lr: 2.0e-07 @@ -147,6 +147,8 @@ policy: sgd_momentum: 0.9 use_distributed_optimizer: true use_precision_aware_optimizer: true + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 clip_grad: ${policy.max_grad_norm} scheduler: start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index ecbe003594..14c3594b95 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -39,7 +39,8 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float ratio_clip_min: float ratio_clip_max: float - ratio_clip_c: float + # Dual-clipping value (should be >1 if enabled; usually set to 3 empirically). None to disable. 
+ ratio_clip_c: float | None use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool truncated_importance_sampling_ratio: float | None diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 73476b7d52..3e40c9d78c 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -12,24 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import NotRequired, TypedDict +from typing import Literal, NotRequired, TypedDict +# TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc +# so that we can type check the configs more rigorously as opposed to saying everything +# is not required. class DataConfig(TypedDict): max_input_seq_length: int - prompt_file: NotRequired[str] - system_prompt_file: NotRequired[str] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] dataset_name: str val_dataset_name: NotRequired[str] add_bos: NotRequired[bool] add_eos: NotRequired[bool] input_key: NotRequired[str] - output_key: NotRequired[str] + output_key: NotRequired[str | None] add_generation_prompt: NotRequired[bool] add_system_prompt: NotRequired[bool] - split: NotRequired[str] - shuffle: NotRequired[bool] - seed: NotRequired[int] + split: NotRequired[str | None] + shuffle: bool + seed: NotRequired[int | None] download_dir: NotRequired[str] train_data_path: NotRequired[str] val_data_paths: NotRequired[dict[str, str]] @@ -40,6 +43,104 @@ class DataConfig(TypedDict): num_workers: NotRequired[int] -class MathDataConfig(DataConfig): +# =============================================================================== +# Eval Dataset Configs +# =============================================================================== +# These configs correspond to the eval datasets in data/datasets/eval_datasets/ +# Note: TypedDict doesn't allow narrowing types in child classes, so each config +# is defined independently 
with common fields repeated. + + +class MMLUEvalDataConfig(TypedDict): + """Config for MMLU and multilingual MMLU datasets. + + Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of: + AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP, + KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG + """ + + max_input_seq_length: int + dataset_name: Literal[ + "mmlu", + "mmlu_AR-XY", + "mmlu_BN-BD", + "mmlu_DE-DE", + "mmlu_EN-US", + "mmlu_ES-LA", + "mmlu_FR-FR", + "mmlu_HI-IN", + "mmlu_ID-ID", + "mmlu_IT-IT", + "mmlu_JA-JP", + "mmlu_KO-KR", + "mmlu_PT-BR", + "mmlu_ZH-CN", + "mmlu_SW-KE", + "mmlu_YO-NG", + ] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +class MMLUProEvalDataConfig(TypedDict): + """Config for MMLU Pro dataset.""" + + max_input_seq_length: int + dataset_name: Literal["mmlu_pro"] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +class AIMEEvalDataConfig(TypedDict): + """Config for AIME datasets.""" + + max_input_seq_length: int + dataset_name: Literal["aime2024", "aime2025"] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +class GPQAEvalDataConfig(TypedDict): + """Config for GPQA datasets.""" + + max_input_seq_length: int + dataset_name: Literal["gpqa", "gpqa_diamond"] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +class MathEvalDataConfig(TypedDict): + """Config for Math datasets.""" + + max_input_seq_length: int + dataset_name: Literal["math", "math500"] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +class LocalMathEvalDataConfig(TypedDict): + """Config for local math datasets loaded from files. + + dataset_name can be a URL or local file path. + Requires additional fields: problem_key, solution_key, file_format, split. 
+ """ + + max_input_seq_length: int + dataset_name: str # URL or file path + problem_key: str + solution_key: str + file_format: Literal["csv", "json"] + split: NotRequired[str | None] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + + +# Union type for all eval dataset configs +EvalDataConfigType = ( + MMLUEvalDataConfig + | MMLUProEvalDataConfig + | AIMEEvalDataConfig + | GPQAEvalDataConfig + | MathEvalDataConfig + | LocalMathEvalDataConfig +) diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index 3992ba69ff..87aff4d387 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -15,7 +15,7 @@ import io import logging import re -from typing import Any, Optional, TypedDict, Union +from typing import Any, NotRequired, TypedDict, Union import ray import torch @@ -41,8 +41,9 @@ class MathEnvConfig(TypedDict): num_workers: int - stop_strings: Optional[list[str]] # Default stop strings for this env - verifier_type: Optional[str] + stop_strings: NotRequired[list[str] | None] # Default stop strings for this env + # The verifier type. None defaults to "math". 
+ verifier_type: NotRequired[str | None] @contextlib.contextmanager diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py index 0ea3cf64aa..d67255ef1e 100644 --- a/nemo_rl/evals/eval.py +++ b/nemo_rl/evals/eval.py @@ -25,7 +25,7 @@ from transformers import AutoTokenizer from nemo_rl.algorithms.utils import set_seed -from nemo_rl.data import MathDataConfig +from nemo_rl.data import EvalDataConfigType from nemo_rl.data.collate_fn import eval_collate_fn from nemo_rl.data.datasets import AllTaskProcessedDataset from nemo_rl.data.llm_message_utils import get_keys_from_message_log @@ -49,12 +49,17 @@ class EvalConfig(TypedDict): save_path: str | None +# TODO: this should be updated, but is left to avoid breaking changes +class _PassThroughMathConfig(TypedDict): + math: MathEnvConfig + + class MasterConfig(TypedDict): eval: EvalConfig generation: GenerationConfig # Fixed: was 'generate' tokenizer: TokenizerConfig # Added missing tokenizer key - data: MathDataConfig - env: MathEnvConfig + data: EvalDataConfigType + env: _PassThroughMathConfig cluster: ClusterConfig diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index d49cf33706..a32c374990 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -158,10 +158,9 @@ async def generate_responses_async( "Generation returned no outputs for a non-empty batch." 
) - pad_token_id = policy_generation.cfg.get("pad_token_id", tokenizer.pad_token_id) generation_outputs = BatchedDataDict.from_batches( ordered_batched_data_dicts, - pad_value_dict={"output_ids": pad_token_id, "logprobs": 0.0}, + pad_value_dict={"output_ids": tokenizer.pad_token_id, "logprobs": 0.0}, ) # Extract everything we need from the generation outputs diff --git a/nemo_rl/models/generation/__init__.py b/nemo_rl/models/generation/__init__.py index 6d25872ae5..15eff22f89 100644 --- a/nemo_rl/models/generation/__init__.py +++ b/nemo_rl/models/generation/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import cast from transformers import PreTrainedTokenizerBase @@ -26,7 +27,13 @@ def configure_generation_config( ) -> GenerationConfig: """Apply specific configurations to generation config.""" # tokenizer setting - config["pad_token_id"] = tokenizer.pad_token_id + if "_pad_token_id" in config: + warnings.warn( + "'_pad_token_id' found in generation config and will be overridden with tokenizer.pad_token_id. 
" + "Note: '_pad_token_id' is intended for internal use and has no effect when set in user-provided configs.", + UserWarning, + ) + config["_pad_token_id"] = tokenizer.pad_token_id if config["stop_token_ids"] is None: config["stop_token_ids"] = [tokenizer.eos_token_id] diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 12e5aecbc1..7b3ed190f5 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -104,9 +104,15 @@ class ResourcesConfig(TypedDict): num_nodes: int +class OptionalResourcesConfig(TypedDict): + # Same as ResourcesConfig, but fields can be null and are validated in grpo.py + gpus_per_node: int | None + num_nodes: int | None + + class ColocationConfig(TypedDict): enabled: bool - resources: NotRequired[ResourcesConfig] + resources: OptionalResourcesConfig class GenerationConfig(TypedDict): @@ -116,12 +122,13 @@ class GenerationConfig(TypedDict): max_new_tokens: int temperature: float top_p: float - top_k: int - model_name: str - stop_token_ids: list[int] - stop_strings: NotRequired[list[str]] - pad_token_id: NotRequired[int] + top_k: int | None + model_name: NotRequired[str] # Not Required b/c GRPO writes this + stop_token_ids: list[int] | None + stop_strings: list[str] | None colocated: NotRequired[ColocationConfig] + # This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config + _pad_token_id: NotRequired[int] class GenerationDatumSpec(TypedDict): diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index f43adf07cf..4efb492cd3 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -89,7 +89,7 @@ def __init__( # Validate sampling parameters early to avoid resource allocation with unsupported configs. 
# The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly. # However, we allow values above certain thresholds for token filtering purposes. - top_k: int | None = self.cfg.get("top_k") + top_k = self.cfg["top_k"] if top_k is not None and top_k != -1 and top_k < TOP_K_THRESHOLD: raise ValueError( ( @@ -117,6 +117,10 @@ def __init__( missing_keys = [ key for key in VllmConfig.__required_keys__ if key not in self.cfg ] + # Also check for model_name which is required by VllmGenerationWorker but marked as NotRequired in GenerationConfig because it's not expected to be set in the job yaml. + if "model_name" not in self.cfg: + missing_keys.append("model_name") + assert not missing_keys, ( f"VLLM Configuration Error: Missing required keys in VllmConfig.\n" f"Missing keys: {', '.join(missing_keys)}\n" @@ -446,7 +450,7 @@ def generate( # Combine results from all tied worker groups combined: BatchedDataDict[GenerationOutputSpec] = BatchedDataDict.from_batches( - results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + results, pad_value_dict={"output_ids": self.cfg["_pad_token_id"]} ) # Verify the output has all required fields @@ -497,7 +501,7 @@ def generate_text( # Combine results from all tied worker groups combined: BatchedDataDict[GenerationOutputSpec] = BatchedDataDict.from_batches( - results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + results, pad_value_dict={"output_ids": self.cfg["_pad_token_id"]} ) # Verify the output has all required fields diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 7ce826a27f..643d035a5f 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -536,7 +536,7 @@ def generate( ) # verify inputs have correct padding - verify_right_padding(data, pad_value=self.cfg["pad_token_id"]) + verify_right_padding(data, pad_value=self.cfg["_pad_token_id"]) # Original 
input length with padding padded_input_length = input_ids.size(1) @@ -570,7 +570,7 @@ def generate( # Create a new tensor with the right size and fill with padding token full_output = torch.full( - (total_length,), self.cfg["pad_token_id"], dtype=input_ids.dtype + (total_length,), self.cfg["_pad_token_id"], dtype=input_ids.dtype ) # Copy original input (with padding) into the beginning diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index c456c62c03..b50700803a 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -528,7 +528,7 @@ async def generate_async( if len(data["input_ids"]) == 0: return - verify_right_padding(data, pad_value=self.cfg["pad_token_id"]) + verify_right_padding(data, pad_value=self.cfg["_pad_token_id"]) input_ids_batch = data["input_ids"] input_lengths_batch = data["input_lengths"] @@ -636,7 +636,7 @@ async def process_single_sample(sample_idx): # Create output_ids tensor for this single item output_ids_single_item = torch.full( (final_output_tensor_len,), - self.cfg["pad_token_id"], + self.cfg["_pad_token_id"], dtype=original_input_ids_single_row.dtype, device=original_input_ids_single_row.device, ) diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index f8502d347d..f8184dfefa 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -12,26 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, NotRequired, TypedDict, Union +from typing import Any, Literal, NotRequired, TypedDict, Union from nemo_rl.models.generation.interfaces import GenerationConfig +class DTensorConfigDisabled(TypedDict): + enabled: Literal[False] + + class DTensorConfig(TypedDict): - enabled: bool - env_vars: NotRequired[dict[str, str]] + enabled: Literal[True] + env_vars: NotRequired[dict[str, str] | None] _v2: NotRequired[bool] - cpu_offload: NotRequired[bool] - sequence_parallel: NotRequired[bool] - activation_checkpointing: NotRequired[bool] - tensor_parallel_size: NotRequired[int] - context_parallel_size: NotRequired[int] - custom_parallel_plan: NotRequired[str] - clear_cache_every_n_steps: NotRequired[int] + cpu_offload: bool + sequence_parallel: bool + activation_checkpointing: bool + tensor_parallel_size: int + context_parallel_size: int + custom_parallel_plan: str | None + clear_cache_every_n_steps: NotRequired[int | None] + + +class SequencePackingConfigDisabled(TypedDict): + enabled: Literal[False] class SequencePackingConfig(TypedDict): - enabled: bool + enabled: Literal[True] train_mb_tokens: int # Not required because some algorithms like SFT don't calculate log probs logprob_mb_tokens: NotRequired[int] @@ -73,7 +81,7 @@ class MegatronSchedulerConfig(TypedDict): end_weight_decay: float weight_decay_incr_style: str lr_decay_style: str - lr_decay_iters: int + lr_decay_iters: NotRequired[int] lr_warmup_iters: int lr_warmup_init: float @@ -87,26 +95,34 @@ class MegatronDDPConfig(TypedDict): data_parallel_sharding_strategy: str +# Type exists to be lax if not specified +class MegatronConfigDisabled(TypedDict): + enabled: Literal[False] + + class MegatronConfig(TypedDict): - enabled: bool - env_vars: NotRequired[dict[str, str]] + enabled: Literal[True] + env_vars: NotRequired[dict[str, str] | None] + # 1 is the minimum recommendation for RL since we almost always need to offload before beginning generation. 
+ # Setting to 0 is faster, but you are more likely to run out of GPU memory. In SFT/DPO, the default is 0. empty_unused_memory_level: int activation_checkpointing: bool - converter_type: str tensor_model_parallel_size: int pipeline_model_parallel_size: int - num_layers_in_first_pipeline_stage: int - num_layers_in_last_pipeline_stage: int + num_layers_in_first_pipeline_stage: int | None + num_layers_in_last_pipeline_stage: int | None context_parallel_size: int pipeline_dtype: str sequence_parallel: bool freeze_moe_router: bool expert_tensor_parallel_size: int expert_model_parallel_size: int + # If True, defer the casting of logits to float32 until the backward pass. + # If you are using logprob_chunk_size, you must set this to True. defer_fp32_logits: NotRequired[bool] - optimizer: NotRequired[MegatronOptimizerConfig] - scheduler: NotRequired[MegatronSchedulerConfig] + optimizer: MegatronOptimizerConfig + scheduler: MegatronSchedulerConfig distributed_data_parallel_config: MegatronDDPConfig @@ -114,7 +130,7 @@ class TokenizerConfig(TypedDict): name: str chat_template: NotRequired[str] # Arguments to pass to tokenizer.apply_chat_template(...). 
This can be used to pass kwargs like enable_thinking=true - chat_template_kwargs: NotRequired[dict[str, Any]] + chat_template_kwargs: NotRequired[dict[str, Any] | None] class PytorchOptimizerConfig(TypedDict): @@ -125,24 +141,29 @@ class PytorchOptimizerConfig(TypedDict): class SinglePytorchSchedulerConfig(TypedDict): name: str kwargs: dict[str, Any] - milestones: NotRequired[list[int]] # Used in SequentialLR configuration + + +class SinglePytorchMilestonesConfig(TypedDict): + milestones: list[int] # Used in SequentialLR configuration SchedulerMilestones = dict[str, list[int]] +class DynamicBatchingConfigDisabled(TypedDict): + enabled: Literal[False] + + class DynamicBatchingConfig(TypedDict): # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length # responses are sorted by sequence length and bucketed into microbatches with a total # amount of tokens is approximately close to 'train_mb_tokens' and 'logprob_mb_tokens' for the # training and logprob stages respectively. - enabled: bool - - ## required if enabled is true - train_mb_tokens: NotRequired[int] - logprob_mb_tokens: NotRequired[int] - sequence_length_round: NotRequired[int] + enabled: Literal[True] + train_mb_tokens: int + logprob_mb_tokens: NotRequired[int] # Only used for some algorithms + sequence_length_round: int class PolicyConfig(TypedDict): @@ -151,21 +172,29 @@ class PolicyConfig(TypedDict): train_global_batch_size: int train_micro_batch_size: int logprob_batch_size: NotRequired[int] - logprob_chunk_size: NotRequired[int] + # If set, log probability computation is chunked along the sequence dimension to avoid GPU OOM (especially during backward pass). + # Within each chunk loop, logits casting (from float16/bfloat16 to float32) is done to prevent holding the entire float32 logits tensor in memory. + # If None, chunking is disabled and the full sequence is processed at once. 
+ logprob_chunk_size: NotRequired[int | None] generation: NotRequired[GenerationConfig] generation_batch_size: NotRequired[ int ] # used in static batched (framework) generation precision: str reward_model_cfg: NotRequired[RewardModelConfig] - dtensor_cfg: DTensorConfig - megatron_cfg: NotRequired[MegatronConfig] + dtensor_cfg: DTensorConfig | DTensorConfigDisabled + megatron_cfg: NotRequired[MegatronConfig | MegatronConfigDisabled] hf_config_overrides: NotRequired[dict[str, Any]] - dynamic_batching: DynamicBatchingConfig - sequence_packing: NotRequired[SequencePackingConfig] + dynamic_batching: DynamicBatchingConfig | DynamicBatchingConfigDisabled + sequence_packing: NotRequired[SequencePackingConfig | SequencePackingConfigDisabled] make_sequence_length_divisible_by: int max_total_sequence_length: int - max_grad_norm: NotRequired[Union[float, int]] + # This sets the clipping norm for the DTensorPolicyWorkers (Megatron's is called clip_grad) + max_grad_norm: NotRequired[float | int | None] refit_buffer_size_gb: NotRequired[float] - optimizer: NotRequired[PytorchOptimizerConfig] - scheduler: NotRequired[list[SinglePytorchSchedulerConfig] | SchedulerMilestones] + optimizer: NotRequired[PytorchOptimizerConfig | None] + scheduler: NotRequired[ + list[SinglePytorchSchedulerConfig | SinglePytorchMilestonesConfig] + | SchedulerMilestones + | None + ] diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 2a20ddcd55..b69555ae15 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -582,7 +582,7 @@ def generate( assert self.cfg["generation"] is not None, "Generation config is not set" result: BatchedDataDict[GenerationOutputSpec] = BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures), - pad_value_dict={"output_ids": self.cfg["generation"]["pad_token_id"]}, + pad_value_dict={"output_ids": self.cfg["generation"]["_pad_token_id"]}, ) # Verify the output has all required fields 
@@ -733,10 +733,10 @@ def save_checkpoint( else: if ( checkpointing_cfg is not None - and checkpointing_cfg.get("model_save_format") == "safetensors" + and checkpointing_cfg.get("model_save_format", None) is not None ): raise ValueError( - "safetensors is only supported with DTensorPolicyWorkerV2 (_v2=true)." + "model_save_format must be None or omitted if using DTensorPolicyWorker (_v2=False)." ) futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 3af30b5554..547f385700 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -289,7 +289,7 @@ def re_enable_float32_expert_bias(megatron_model): data_parallel_random_init=cfg.rng.data_parallel_random_init, pre_wrap_hook=pre_wrap_hook, wrap_cast_model_output_to_fp32=( - not policy_cfg["megatron_cfg"].get("defer_fp32_logits", None) + not policy_cfg["megatron_cfg"].get("defer_fp32_logits", False) ), ) if load_optimizer: @@ -637,6 +637,14 @@ def __init__( assert optimizer_offload_fraction == 1.0, ( "Currently for optimizer offloading, only optimizer_offload_fraction=1.0 is supported" ) + if ( + "logprob_chunk_size" in self.cfg + and self.cfg["logprob_chunk_size"] is not None + and self.cfg["logprob_chunk_size"] > 0 + ): + assert self.cfg["megatron_cfg"]["defer_fp32_logits"], ( + "defer_fp32_logits must be True if logprob_chunk_size is set" + ) checkpoint_config = CheckpointConfig( save_interval=100, @@ -757,7 +765,7 @@ def __init__( overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, pre_wrap_hook=self.megatron_cfg.rng.data_parallel_random_init, wrap_cast_model_output_to_fp32=( - not self.cfg["megatron_cfg"].get("defer_fp32_logits", None) + not self.cfg["megatron_cfg"].get("defer_fp32_logits", False) ), ) print("Loading the Reference Model") diff --git 
a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index 74ce5ac0cb..05e0ee2f3a 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -44,7 +44,7 @@ class CheckpointingConfig(TypedDict): the metric should be taken from the validation or training metrics. higher_is_better (bool): Whether higher values of the metric indicate better performance. keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. - model_save_format (str): Format for saving model ("torch_save" or "safetensors"). + model_save_format (str | None): Format for saving model (v2 allowed values: "torch_save" or "safetensors", v1 allowed values: None). save_consolidated (bool): Whether to save consolidated checkpoints (for HF compatibility). model_cache_dir (str): Directory for model cache (for safetensors format). model_repo_id (str): Repository ID for the model (for safetensors format). @@ -59,7 +59,7 @@ class CheckpointingConfig(TypedDict): keep_top_k: NotRequired[int] checkpoint_must_save_by: NotRequired[str | None] # New nemo-automodel integration fields - model_save_format: NotRequired[str] # Default: "safetensors" + model_save_format: NotRequired[str | None] # Default: "safetensors" save_consolidated: NotRequired[bool] # Default: False model_cache_dir: NotRequired[str] # Default: "" model_repo_id: NotRequired[str] # Default: "" diff --git a/pyproject.toml b/pyproject.toml index 348a214b6c..2024533066 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,7 +246,7 @@ ignore = ["D417", "D10", "F841"] convention = "google" # Section to exclude errors for different file types -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Ignore all directories named `tests`. "tests/**" = ["D"] # Ignore all files that end in `_test.py`. 
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index f1846477f9..2674f73a6d 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -437,7 +437,7 @@ async def _generate_async(vllm_policy, tokenizer, test_input_data, greedy=False) # Extract in correct order outputs = [item for _, item in collected_indexed_outputs] - pad_token_id = vllm_policy.cfg.get("pad_token_id", tokenizer.pad_token_id) + pad_token_id = vllm_policy.cfg.get("_pad_token_id", tokenizer.pad_token_id) outputs = BatchedDataDict.from_batches( outputs, pad_value_dict={"output_ids": pad_token_id, "logprobs": 0.0}, diff --git a/tests/unit/models/generation/test_vllm_large_model.py b/tests/unit/models/generation/test_vllm_large_model.py index cc807509d3..89eaece234 100644 --- a/tests/unit/models/generation/test_vllm_large_model.py +++ b/tests/unit/models/generation/test_vllm_large_model.py @@ -168,7 +168,7 @@ async def test_vllm_large_model( # Extract in correct order outputs = [item for _, item in collected_indexed_outputs] - pad_token_id = async_policy.cfg.get("pad_token_id", tokenizer.pad_token_id) + pad_token_id = async_policy.cfg.get("_pad_token_id", tokenizer.pad_token_id) outputs = BatchedDataDict.from_batches( outputs, pad_value_dict={"output_ids": pad_token_id, "logprobs": 0.0}, diff --git a/tests/unit/test_config_validation.py b/tests/unit/test_config_validation.py index 011076c5ee..349a024ab0 100644 --- a/tests/unit/test_config_validation.py +++ b/tests/unit/test_config_validation.py @@ -14,225 +14,121 @@ import glob import os -import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type, Union, get_type_hints +from typing import Any, Dict, Type import pytest from omegaconf import OmegaConf -from typing_extensions import NotRequired - -from nemo_rl.algorithms.distillation import DistillationConfig -from 
nemo_rl.algorithms.dpo import DPOConfig -from nemo_rl.algorithms.grpo import GRPOConfig, GRPOLoggerConfig -from nemo_rl.algorithms.sft import SFTConfig -from nemo_rl.data import DataConfig -from nemo_rl.distributed.virtual_cluster import ClusterConfig -from nemo_rl.models.policy import PolicyConfig -from nemo_rl.utils.checkpoint import CheckpointingConfig +from pydantic import TypeAdapter, ValidationError + +from nemo_rl.algorithms.distillation import MasterConfig as DistillationMasterConfig +from nemo_rl.algorithms.dpo import MasterConfig as DPOMasterConfig +from nemo_rl.algorithms.grpo import MasterConfig as GRPOMasterConfig +from nemo_rl.algorithms.rm import MasterConfig as RMMasterConfig +from nemo_rl.algorithms.sft import MasterConfig as SFTMasterConfig +from nemo_rl.evals.eval import MasterConfig as EvalMasterConfig from nemo_rl.utils.config import load_config_with_inheritance -from nemo_rl.utils.logger import LoggerConfig # All tests in this module should run first pytestmark = pytest.mark.run_first +if not OmegaConf.has_resolver("mul"): + OmegaConf.register_new_resolver("mul", lambda a, b: a * b) -def get_keys_from_typeddict(typed_dict_class: dict) -> Set[str]: - """Extract required keys from a TypedDict class, excluding NotRequired fields.""" - type_hints = get_type_hints(typed_dict_class, include_extras=True) - required_keys = set() - optional_keys = set() - - for key, annotation in type_hints.items(): - # Check if the field is marked as NotRequired - if hasattr(annotation, "__origin__") and (annotation.__origin__ is NotRequired): - optional_keys.add(key) - - ## check for Optional fields - elif ( - hasattr(annotation, "__origin__") - and annotation.__origin__ is Union - and type(None) in annotation.__args__ - ): - raise ValueError( - f"Please use the NotRequired annotation instead of Optional for key {key}" - ) - else: - required_keys.add(key) - - return required_keys, optional_keys - - -def validate_nested_config_section( - config_dict: Dict[str, Any], 
config_class: Type, section_path: str -) -> List[str]: - """Recursively validate a config section and its nested TypedDict fields.""" - errors = [] - type_hints = get_type_hints(config_class, include_extras=True) - - for key, annotation in type_hints.items(): - current_path = f"{section_path}.{key}" if section_path else key - - # Check if the field is marked as NotRequired - is_optional = hasattr(annotation, "__origin__") and ( - annotation.__origin__ is NotRequired - ) - - # If the key is not in the config and it's required, add an error - if key not in config_dict: - if not is_optional: - errors.append(f"Missing required key in {section_path}: {key}") - continue - - # Get the value from the config - value = config_dict[key] - # If the annotation is a TypedDict (nested config), validate it recursively - if hasattr(annotation, "__annotations__") and isinstance(value, dict): - # This is a nested TypedDict, validate it recursively - nested_errors = validate_nested_config_section( - value, annotation, current_path +def validate_config_section( + section_config: Dict[str, Any], + config_class: Type, + config_file: str, +) -> None: + """Validate a config section against its TypedDict class using Pydantic. + + Raises AssertionError with formatted error messages if validation fails. 
+ """ + if not isinstance(section_config, dict): + raise TypeError("Config must be a dictionary") + + # Use Pydantic's TypeAdapter to validate the TypedDict + adapter = TypeAdapter(config_class) + try: + adapter.validate_python(section_config) + except ValidationError as e: + # Format errors nicely with actual values + error_messages = [] + for error in e.errors(): + path_parts = [] + if error["loc"]: + path_parts.extend(str(loc) for loc in error["loc"]) + path = ".".join(path_parts) if path_parts else "root" + + # Only include the actual input value for non-missing fields + # For missing fields, the 'input' is the parent dict which is confusing + input_info = "" + if "input" in error and error["type"] != "missing": + input_value = error.get("input") + # Truncate very long values for readability + input_str = str(input_value) + if len(input_str) > 100: + input_str = input_str[:97] + "..." + input_info = f" (got: {input_str})" + + error_messages.append( + f" {path}: {error['msg']} (type={error['type']}){input_info}" ) - errors.extend(nested_errors) - elif hasattr(annotation, "__origin__") and annotation.__origin__ is Optional: - # Handle Optional[TypedDict] case - if ( - value is not None - and hasattr(annotation.__args__[0], "__annotations__") - and isinstance(value, dict) - ): - nested_errors = validate_nested_config_section( - value, annotation.__args__[0], current_path - ) - errors.extend(nested_errors) - - # Check for extra keys (keys in config that are not in the TypedDict) - required_keys, optional_keys = get_keys_from_typeddict(config_class) - all_valid_keys = required_keys | optional_keys - - for key in config_dict.keys(): - if key not in all_valid_keys: - errors.append(f"Extra key in {section_path}: {key}") - - return errors + config_info = f"\n\nConfig file: {config_file}" if config_file else "" + raise AssertionError( + f"Config validation failed:{config_info}\n" + "\n".join(error_messages) + ) from e -def validate_config_section( - config_dict: 
Dict[str, Any], config_class: dict, section_name: str -) -> List[str]: - """Validate a specific section of a config against its TypedDict class.""" - errors = [] - required_keys, optional_keys = get_keys_from_typeddict(config_class) - - if section_name not in config_dict: - errors.append(f"Missing required section: {section_name}") - return errors - section_config = config_dict[section_name] - if not isinstance(section_config, dict): - errors.append(f"Section {section_name} must be a dictionary") - return errors +absolute_path = os.path.abspath(__file__) +configs_dir = Path( + os.path.join(os.path.dirname(absolute_path), "../../examples/configs") +).resolve() +config_files = glob.glob(str(configs_dir / "**/*.yaml"), recursive=True) +assert len(config_files) > 0, "No config files found" - # Use the new recursive validation function - nested_errors = validate_nested_config_section( - section_config, config_class, section_name - ) - errors.extend(nested_errors) - return errors +@pytest.mark.parametrize("config_file", config_files) +def test_all_config_files_have_required_keys(config_file): + """Test that all config files in examples/configs have all required keys for their respective sections.""" + print(f"\nValidating config file: {config_file}") + + # Load the config file with inheritance + config = load_config_with_inheritance(config_file) + config_dict = OmegaConf.to_container(config, resolve=True) + + if config_dict is None: + raise AssertionError(f"Config file {config_file} is empty or invalid") + + # Determine which MasterConfig to use based on the config contents + master_config_class = None + config_type = None + + if "/evals/" in config_file: + master_config_class = EvalMasterConfig + config_type = "eval" + elif "distillation" in config_dict: + master_config_class = DistillationMasterConfig + config_type = "distillation" + elif "dpo" in config_dict: + master_config_class = DPOMasterConfig + config_type = "dpo" + elif "sft" in config_dict: + 
master_config_class = SFTMasterConfig + config_type = "sft" + elif "grpo" in config_dict: + master_config_class = GRPOMasterConfig + config_type = "grpo" + elif "rm" in config_dict: + master_config_class = RMMasterConfig + config_type = "rm" + else: + raise AssertionError( + f"Could not determine algorithm type for config {config_file}." + ) -def test_all_config_files_have_required_keys(): - """Test that all config files in examples/configs have all required keys for their respective sections.""" - if not OmegaConf.has_resolver("mul"): - OmegaConf.register_new_resolver("mul", lambda a, b: a * b) - - absolute_path = os.path.abspath(__file__) - configs_dir = Path( - os.path.join(os.path.dirname(absolute_path), "../../examples/configs") - ) - - # Get all YAML config files - config_files = glob.glob(str(configs_dir / "**/*.yaml"), recursive=True) - - assert len(config_files) > 0, "No config files found" - - all_errors = [] - - for config_file in config_files: - print(f"\nValidating config file: {config_file}") - - try: - # Load the config file with inheritance - config = load_config_with_inheritance(config_file) - config_dict = OmegaConf.to_container(config, resolve=True) - - if config_dict is None: - all_errors.append(f"Config file {config_file} is empty or invalid") - continue - - # Validate each section against its corresponding config class - section_validations = [ - ("policy", PolicyConfig), - ("data", DataConfig), - ("cluster", ClusterConfig), - ("checkpointing", CheckpointingConfig), - ] - - # Add algorithm-specific validation - if "distillation" in config_dict: - section_validations.extend( - [("distillation", DistillationConfig), ("logger", LoggerConfig)] - ) - # Distillation also has a loss_fn section - if "loss_fn" in config_dict: - from nemo_rl.algorithms.loss_functions import DistillationLossConfig - - section_validations.append(("loss_fn", DistillationLossConfig)) - elif "dpo" in config_dict: - section_validations.extend( - [("dpo", DPOConfig), 
("logger", LoggerConfig)] - ) - elif "sft" in config_dict: - section_validations.extend( - [("sft", SFTConfig), ("logger", LoggerConfig)] - ) - elif "grpo" in config_dict: - section_validations.extend( - [("grpo", GRPOConfig), ("logger", GRPOLoggerConfig)] - ) - # GRPO also has a loss_fn section - if "loss_fn" in config_dict: - from nemo_rl.algorithms.loss_functions import ClippedPGLossConfig - - section_validations.append(("loss_fn", ClippedPGLossConfig)) - else: - warnings.warn( - f"Could not determine algorithm type for config {config_file}. Continuing..." - ) - continue - - # Validate each section - for section_name, config_class in section_validations: - errors = validate_config_section( - config_dict, config_class, section_name - ) - for error in errors: - all_errors.append(f"{config_file}: {error}") - - # Additional validation for GRPO configs that have an 'env' section - if "grpo" in config_dict and "env" in config_dict: - if not isinstance(config_dict["env"], dict): - all_errors.append( - f"{config_file}: env section must be a dictionary" - ) - - except Exception as e: - all_errors.append(f"Error processing {config_file}: {str(e)}") - - # If there are any errors, fail the test with detailed error messages - if all_errors: - error_message = "\n".join(all_errors) - pytest.fail(f"Config validation failed:\n{error_message}") - - print(f"\n✅ Successfully validated {len(config_files)} config files") + # Validate the entire config using the appropriate MasterConfig + validate_config_section(config_dict, master_config_class, config_file) diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 48f44b8349..5fc984e246 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -36,6 +36,7 @@ "grpo": "examples/configs/grpo_math_1B.yaml", "vlm_grpo": "examples/configs/vlm_grpo_3B.yaml", "distillation": "examples/configs/distillation_math.yaml", + "rm": 
"examples/configs/rm.yaml", "dapo": "examples/configs/grpo_math_1B.yaml", } @@ -247,41 +248,3 @@ def test_all_recipes_start_with_algo_hyphen(all_recipe_yaml_rel_paths): assert algo in expected_algos, ( f"Recipe {recipe_yaml} has unexpected algo {algo}" ) - - -@pytest.mark.parametrize("algo, algo_base_yaml", ALGO_MAPPING_TO_BASE_YAML.items()) -def test_all_recipes_can_merge_configs_with_base_config( - all_recipe_yaml_rel_paths, all_test_suites, algo, algo_base_yaml -): - from omegaconf import OmegaConf - - from nemo_rl.utils.config import load_config - - base_yaml = os.path.join(project_root, algo_base_yaml) - base_config = OmegaConf.load(base_yaml) - # Would result in an error if we couldn't merge our config with the recipe's config - OmegaConf.set_struct(base_config, True) - for recipe_yaml in all_recipe_yaml_rel_paths: - if not os.path.basename(recipe_yaml).startswith(algo): - # Skipping here b/c we test that all recipes start with the algo-hyphen in - # test_all_recipes_start_with_algo_hyphen() - continue - recipe_yaml_path = os.path.join(recipes_dir, recipe_yaml) - recipe_config = load_config(recipe_yaml_path) - OmegaConf.set_struct(recipe_config, True) - - # Work around ALLOWED_ADDITIONAL_CONFIG_KEYS by manually adding allowed keys to the base config - # This prevents merge conflicts when recipe configs contain keys not present in base configs - for key in ALLOWED_ADDITIONAL_CONFIG_KEYS: - if OmegaConf.select(recipe_config, key): - OmegaConf.update( - base_config, - key, - OmegaConf.select(recipe_config, key), - force_add=True, - ) - - # This will raise a error if the config can't be merged - print(f"Merging {recipe_yaml} with {base_yaml}") - merged_config = OmegaConf.merge(base_config, recipe_config) - print(merged_config) diff --git a/tools/config_cli.py b/tools/config_cli.py index 38500cb02a..04780e7747 100755 --- a/tools/config_cli.py +++ b/tools/config_cli.py @@ -46,10 +46,12 @@ tools/config_cli.py minimize examples/configs/dpo.yaml 
examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml --in-place # Minimize all llm the configs: - for algo in grpo dpo sft; do + for algo in grpo dpo sft distillation; do base_config=examples/configs/${algo}.yaml if [[ ${algo} == grpo ]]; then base_config=examples/configs/grpo_math_1B.yaml + elif [[ ${algo} == distillation ]]; then + base_config=examples/configs/distillation_math.yaml fi for recipe in examples/configs/recipes/llm/${algo}-*.yaml; do tools/config_cli.py minimize $base_config $recipe --in-place