diff --git a/docs/adding-new-models.md b/docs/adding-new-models.md index 30fab20a3e..c9951acdc7 100644 --- a/docs/adding-new-models.md +++ b/docs/adding-new-models.md @@ -190,4 +190,125 @@ uv run --extra mcore tools/model_diagnostics/3.check_hf_model_embeddings_untrain - Thresholds can be adjusted via flags: - `--near-zero-threshold` (default: `1e-10`) - `--identical-threshold` (default: `1e-8`) -- If any near-zero or identical rows are reported, the model may have issues of numerical instability (e.g., inf grad norms) during post-training if any of these problematic tokens are encountered. We have observed this happening when special tokens are reserved in the tokenizer and embedding, but none are encountered during pre-training. It may help to initialize these embeddings similar to how they were initialize during pre-training. \ No newline at end of file +- If any near-zero or identical rows are reported, the model may have issues of numerical instability (e.g., inf grad norms) during post-training if any of these problematic tokens are encountered. We have observed this happening when special tokens are reserved in the tokenizer and embedding, but none are encountered during pre-training. It may help to initialize these embeddings similarly to how they were initialized during pre-training. + +## [4.vllm_precision_compilation_test.py](https://github.com/NVIDIA-NeMo/RL/blob/main/tools/model_diagnostics/4.vllm_precision_compilation_test.py) + +Tests vLLM precision compilation by comparing log probabilities across different compilation modes and configurations. This script helps diagnose numerical precision issues that commonly arise when using different vLLM compilation settings. **Note that this is not a strict pass/fail test** - it's designed to help you understand and investigate numerical discrepancies. + +```sh +# Example run +uv run --extra vllm tools/model_diagnostics/4.vllm_precision_compilation_test.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + +# Typical output shows mixed results: +# Eager and cuda graph mode lps: FAILED - Arrays are different +... +# Eager and cuda graph mode lps with torch inductor precision flag: FAILED - Arrays are different +... +# Eager and cuda graph mode lps with use_inductor disabled: PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001) +``` + +See example output for the model `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`: +``` +==================================================================================================== +Eager and cuda graph mode lps (prompt lps): FAILED - Arrays are different + Detailed error: +Not equal to tolerance rtol=0.001, atol=0.001 + +Mismatched elements: 96 / 515 (18.6%) +Max absolute difference among violations: 0.3885002 +Max relative difference among violations: 0.20179409 + ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01, + -3.443364e-04, nan, nan, nan, + nan, nan, nan, nan,... + DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01, + -3.047717e-04, nan, nan, nan, + nan, nan, nan, nan,...
+==================================================================================================== +==================================================================================================== +Eager and cuda graph mode lps (generation lps): FAILED - Arrays are different + Detailed error: +Not equal to tolerance rtol=0.001, atol=0.001 + +nan location mismatch: + ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan, + nan, nan], + [-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,... + DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan, + nan, nan], + [-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,... +==================================================================================================== +... +==================================================================================================== +Eager and cuda graph mode lps with torch inductor precision flag (prompt lps): FAILED - Arrays are different + Detailed error: +Not equal to tolerance rtol=0.001, atol=0.001 + +Mismatched elements: 96 / 515 (18.6%) +Max absolute difference among violations: 0.3885002 +Max relative difference among violations: 0.20179409 + ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01, + -3.443364e-04, nan, nan, nan, + nan, nan, nan, nan,... + DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01, + -3.047717e-04, nan, nan, nan, + nan, nan, nan, nan,... +==================================================================================================== +==================================================================================================== +Eager and cuda graph mode lps with torch inductor precision flag (generation lps): FAILED - Arrays are different + Detailed error: +Not equal to tolerance rtol=0.001, atol=0.001 + +nan location mismatch: + ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan, + nan, nan], + [-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,... + DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan, + nan, nan], + [-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,... +==================================================================================================== +... +Eager and cuda graph mode lps with use_inductor disabled (prompt lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001) +Eager and cuda graph mode lps with use_inductor disabled (generation lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001) +``` + +**What this script tests:** + +The script compares both prompt and generation logprobs under the following setups: + +1. **Eager vs CUDA Graph Mode**: Compares log probabilities between eager execution (ground truth) and CUDA graph compilation mode + - **⚠️ Commonly fails**: This comparison often shows discrepancies due to compilation optimizations +2. **Torch Inductor Precision**: Tests with the `TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` environment variable + - **⚠️ May help**: This flag can reduce the discrepancies but typically doesn't resolve all of the numerical differences +3. 
**Inductor Disabled**: Verifies that disabling Torch Inductor compilation (`use_inductor=False`) maintains output consistency + - **✅ Usually works well**: This configuration often produces results very close to eager mode + - **Note**: `use_inductor=False` disables Inductor compilation but keeps CUDA graph capture active for compatible operations + +**Performance vs Accuracy Trade-offs:** + +The different compilation modes offer distinct trade-offs between accuracy and performance: + +- **Eager Mode** (`enforce_eager=True`): Highest accuracy (ground truth) but slowest execution +- **CUDA Graph Mode with Inductor Disabled** (`enforce_eager=False` and `compilation_config={"use_inductor": False}`): Near-eager accuracy with significant speedup from CUDA graph optimization +- **CUDA Graph Mode with Inductor Enabled** (`enforce_eager=False` and `compilation_config={"use_inductor": True}`): Potentially fastest execution with custom Triton kernels (since Triton is the current backend of Inductor), but may introduce numerical differences. To improve accuracy, try the Torch Inductor precision flag: `export TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` + +**Note**: Performance characteristics vary by model. For example, `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` shows similar speed with `use_inductor=True` and `use_inductor=False`, making the accuracy-preserving option preferable. + +**Why this matters:** + +- **Debugging**: Helps identify which compilation settings cause numerical differences +- **Configuration**: Shows which settings work best for your model +- **Understanding**: Reveals how compilation affects model outputs + +**When to use:** + +- **Model integration** - understand numerical behavior across vLLM configurations +- **Debugging** - investigate differences between development and production +- **Research** - study the impact of compilation strategies on precision + +**Interpreting results:** + +- **Eager vs CUDA Graph failures are normal** - don't panic if this fails +- **Focus on patterns** - some models are more sensitive than others +- **Use as guidance** - helps choose reliable compilation settings +- **Balance precision vs performance** - choose what works for your use case \ No newline at end of file diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 66c71105ff..8677e4bc64 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -176,10 +176,14 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} + # when enforce_eager is False, you can optionally set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy; + # with use_inductor=False, vLLM uses its custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 enforce_eager: False use_deep_gemm: False num_last_layers_in_bf16: 0 num_first_layers_in_bf16: 0 + vllm_kwargs: {} colocated: # true: generation shares training GPUs # false: uses dedicated generation resources diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml index 7513390aaa..95dea397ce 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml @@ -101,7 +101,13 @@ policy: pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: 
${policy.max_total_sequence_length} - enforce_eager: True + enforce_eager: False + vllm_kwargs: + compilation_config: + # when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy; + # with use_inductor=False, vLLM uses its custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 + use_inductor: False colocated: # true: generation shares training GPUs # false: uses dedicated generation resources diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 5260a6e96a..f65275d042 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -36,9 +36,16 @@ ) # pragma: no cover class VllmAsyncGenerationWorker(BaseVllmGenerationWorker): def _create_engine(self, llm_kwargs: dict[str, Any]) -> None: + from vllm.config import CompilationConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.v1.engine.async_llm import AsyncLLM + # (TODO: zhiyul) Remove this workaround after upgrading to a vLLM version where the compilation_config passing issue is resolved. + if llm_kwargs.get("compilation_config", None): + llm_kwargs["compilation_config"] = CompilationConfig( + **llm_kwargs["compilation_config"] + ) + self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs)) async def post_init_async(self): diff --git a/pyrefly.toml b/pyrefly.toml index e9717a1ed0..f4e742cd6c 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -119,6 +119,7 @@ project-includes = [ "nemo_rl/utils/venvs.py", "tools/model_diagnostics/1.max_model_len_respected.py", "tools/model_diagnostics/2.long_generation_decode_vs_prefill.py", + "tools/model_diagnostics/4.vllm_precision_compilation_test.py", ] # Disable TypedDict mutation errors since TypedDict objects are regular dicts at runtime diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 032909295e..e85c3c76fb 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -40,6 +40,11 @@ "vlm_grpo": "examples/configs/vlm_grpo_3B.yaml", } +# Configuration keys that are allowed to be added to base configs during testing +# These keys may exist in recipe configs but not in base configs, so we need to +# manually add them to avoid merge conflicts during config validation +ALLOWED_ADDITIONAL_CONFIG_KEYS = ["policy.generation.vllm_kwargs"] + @pytest.fixture def nightly_test_suite(): @@ -298,6 +303,18 @@ def test_all_recipes_can_merge_configs_with_base_config( recipe_yaml_path = os.path.join(recipes_dir, recipe_yaml) recipe_config = load_config(recipe_yaml_path) OmegaConf.set_struct(recipe_config, True) + + # Apply ALLOWED_ADDITIONAL_CONFIG_KEYS by manually adding the allowed keys to the base config + # This prevents merge conflicts when recipe configs contain keys not present in base configs + for key in ALLOWED_ADDITIONAL_CONFIG_KEYS: + if OmegaConf.select(recipe_config, key): + OmegaConf.update( + base_config, + key, + OmegaConf.select(recipe_config, key), + force_add=True, + ) + + # This will raise a error if the config can't be merged print(f"Merging {recipe_yaml} with {base_yaml}") merged_config = OmegaConf.merge(base_config, recipe_config) diff --git a/tools/model_diagnostics/4.vllm_precision_compilation_test.py b/tools/model_diagnostics/4.vllm_precision_compilation_test.py new file mode 100644 
index 0000000000..276f88943f --- /dev/null +++ b/tools/model_diagnostics/4.vllm_precision_compilation_test.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from contextlib import contextmanager + +import numpy as np +import torch +from vllm import LLM, SamplingParams + + +@contextmanager +def environment(env_vars): + """Context manager to temporarily set environment variables. + + Args: + env_vars (dict): Dictionary of environment variable names and values to set + + Example: + with environment({"CUDA_VISIBLE_DEVICES": "0"}): + # Code here runs with CUDA_VISIBLE_DEVICES=0 + pass + # Environment variables are restored here + """ + # Store original values + original_values = {} + for key in env_vars: + if key in os.environ: + original_values[key] = os.environ[key] + else: + original_values[key] = None + + # Set new values + for key, value in env_vars.items(): + if value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = str(value) + + try: + yield + finally: + # Restore original values + for key, value in original_values.items(): + if value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = value + + +def extract_logprobs(logprobs): + output = [] + for lp in logprobs: + if lp is not None: + output.append(list(lp.values())[0].logprob) + return output + + +def pad_logprobs_list(logprobs_list): + """Pad a list of logprobs lists into a numpy array. + + Args: + logprobs_list (list): List of lists, where each inner list contains logprobs + + Returns: + np.ndarray: Padded numpy array with shape (num_sequences, max_length) + """ + if not logprobs_list: + return np.array([]) + + max_length = max(len(lp) for lp in logprobs_list) + padded_array = np.full((len(logprobs_list), max_length), np.nan, dtype=np.float32) + + for i, lp in enumerate(logprobs_list): + padded_array[i, : len(lp)] = lp + + return padded_array + + +def assert_logprobs_close(actual, expected, test_name, atol=1e-3, rtol=1e-3): + """Assert that two logprobs arrays are close to each other. 
+ + Args: + actual: The actual logprobs array + expected: The expected logprobs array + test_name (str): Name of the test for error messages + atol (float): Absolute tolerance + rtol (float): Relative tolerance + """ + try: + np.testing.assert_allclose(actual, expected, atol=atol, rtol=rtol) + print( + f"{test_name}: PASSED - Arrays are close within tolerance (atol={atol}, rtol={rtol})" + ) + except AssertionError as e: + print("=" * 100) + print(f"{test_name}: FAILED - Arrays are different") + print(f" Detailed error: {e}") + print("=" * 100) + + +def get_logprobs(llm, prompts, sampling_params): + outputs = llm.generate(prompts, sampling_params) + prompt_lps = [] + generation_lps = [] + + # Collect all logprobs + for output in outputs: + prompt_logprobs = extract_logprobs(output.prompt_logprobs) + generation_logprobs = extract_logprobs(output.outputs[0].logprobs) + prompt_lps.append(prompt_logprobs) + generation_lps.append(generation_logprobs) + + # Use common padding function + padded_prompt_lps = pad_logprobs_list(prompt_lps) + padded_generation_lps = pad_logprobs_list(generation_lps) + + return padded_prompt_lps, padded_generation_lps + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + nargs="?", + default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + ) + args = parser.parse_args() + seed = 0 + + sampling_params = SamplingParams( + temperature=1.0, + top_p=1.0, + max_tokens=8192, + prompt_logprobs=0, + logprobs=0, + seed=seed, + ) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "<|begin▁of▁sentence|><|User|>Think step-by-step to solve the following problem. Output your answer inside of \\\\boxed{} tags.:\n$A B C D$ is a rectangle with $A B=20$ and $B C=3$. A circle with radius 5, centered at the midpoint of $D C$, meets the rectangle at four points: $W, X, Y$, and $Z$. 
Find the area of quadrilateral $W X Y Z$.\n\nLet's think step-by-step<|Assistant|>\n", + ] + + common_llm_kwargs = { + "model": args.model, + "trust_remote_code": True, + "enable_prefix_caching": True, + "enable_chunked_prefill": True, + } + + eager_prompt_lps, eager_generation_lps = get_logprobs( + LLM(enforce_eager=True, **common_llm_kwargs), # eager mode for ground truth lps + prompts, + sampling_params, + ) + + torch.cuda.empty_cache() + + cuda_graph_prompt_lps, cuda_graph_generation_lps = get_logprobs( + LLM(enforce_eager=False, **common_llm_kwargs), # cuda graph mode + prompts, + sampling_params, + ) + + assert_logprobs_close( + cuda_graph_prompt_lps, + eager_prompt_lps, + "Eager and cuda graph mode lps (prompt lps)", + ) + assert_logprobs_close( + cuda_graph_generation_lps, + eager_generation_lps, + "Eager and cuda graph mode lps (generation lps)", + ) + + torch.cuda.empty_cache() + + with environment(env_vars={"TORCHINDUCTOR_EMULATE_PRECISION_CASTS": "1"}): + cuda_graph_prompt_lps_w_flag, cuda_graph_generation_lps_w_flag = get_logprobs( + LLM(enforce_eager=False, **common_llm_kwargs), + prompts, + sampling_params, + ) + + assert_logprobs_close( + cuda_graph_prompt_lps_w_flag, + eager_prompt_lps, + "Eager and cuda graph mode lps with torch inductor precision flag (prompt lps)", + ) + assert_logprobs_close( + cuda_graph_generation_lps_w_flag, + eager_generation_lps, + "Eager and cuda graph mode lps with torch inductor precision flag (generation lps)", + ) + + torch.cuda.empty_cache() + + ( + cuda_graph_prompt_lps_w_inductor_disabled, + cuda_graph_generation_lps_w_inductor_disabled, + ) = get_logprobs( + LLM( + enforce_eager=False, + compilation_config={"use_inductor": False}, + **common_llm_kwargs, + ), + prompts, + sampling_params, + ) + + assert_logprobs_close( + cuda_graph_prompt_lps_w_inductor_disabled, + eager_prompt_lps, + "Eager and cuda graph mode lps with use_inductor disabled (prompt lps)", + ) + assert_logprobs_close( + cuda_graph_generation_lps_w_inductor_disabled, + eager_generation_lps, + "Eager and cuda graph mode lps with use_inductor disabled (generation lps)", + ) + + +if __name__ == "__main__": + main()
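
As a usage note for the configuration changes above: the accuracy-preserving setting can also be enabled per run through the `++` override quoted in the YAML comments, without editing the config file. A minimal sketch, assuming `examples/run_grpo_math.py` is the GRPO entry point (the launcher script name is an assumption; the override key and config path are taken verbatim from the config comments in this change):

```sh
# Hedged example: keep CUDA graphs (enforce_eager: False in the base config) while
# disabling Inductor-generated Triton kernels for better numerical accuracy.
# The launcher script name is assumed; substitute your own entry point if it differs.
uv run examples/run_grpo_math.py \
    --config examples/configs/grpo_math_1B.yaml \
    ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False
```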