123 changes: 122 additions & 1 deletion docs/adding-new-models.md
@@ -190,4 +190,125 @@ uv run --extra mcore tools/model_diagnostics/3.check_hf_model_embeddings_untrain
- Thresholds can be adjusted via flags:
- `--near-zero-threshold` (default: `1e-10`)
- `--identical-threshold` (default: `1e-8`)
- If any near-zero or identical rows are reported, the model may hit numerical instability (e.g., inf grad norms) during post-training whenever one of these problematic tokens is encountered. We have observed this when special tokens are reserved in the tokenizer and embedding but never seen during pre-training. It may help to re-initialize these embeddings the way they were initialized during pre-training; a minimal sketch follows.
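
A minimal sketch of one way to do that, assuming a Hugging Face causal LM and a simple normal re-initialization using the config's `initializer_range` (the model id and output path are placeholders, and the actual pre-training init may differ):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder model id; substitute your checkpoint.
model = AutoModelForCausalLM.from_pretrained("my-org/my-model")
emb = model.get_input_embeddings().weight  # (vocab_size, hidden_size)

with torch.no_grad():
    # Rows whose largest-magnitude entry is below the script's default
    # --near-zero-threshold; the script's exact criterion may differ.
    near_zero = emb.abs().max(dim=1).values < 1e-10
    # Many HF configs expose the pre-training init std as `initializer_range`.
    std = getattr(model.config, "initializer_range", 0.02)
    emb[near_zero] = (
        torch.randn(
            (int(near_zero.sum()), emb.shape[1]),
            dtype=emb.dtype,
            device=emb.device,
        )
        * std
    )

model.save_pretrained("my-model-fixed-embeddings")
```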

## [4.vllm_precision_compilation_test.py](https://github.com/NVIDIA-NeMo/RL/blob/main/tools/model_diagnostics/4.vllm_precision_compilation_test.py)

Tests vLLM numerical precision under different compilation modes and configurations by comparing log probabilities between them. This helps diagnose precision issues that commonly arise across vLLM compilation settings. **Note that this is not a strict pass/fail test** - it is designed to help you understand and investigate numerical discrepancies.

```sh
# Example run
uv run --extra vllm tools/model_diagnostics/4.vllm_precision_compilation_test.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# Typical output shows mixed results:
# Eager and cuda graph mode lps: FAILED - Arrays are different
...
# Eager and cuda graph mode lps with torch inductor precision flag: FAILED - Arrays are different
...
# Eager and cuda graph mode lps with use_inductor disabled: PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
```

See the example output below for the model `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`:
```
====================================================================================================
Eager and cuda graph mode lps (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
-3.443364e-04, nan, nan, nan,
nan, nan, nan, nan,...
DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
-3.047717e-04, nan, nan, nan,
nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

nan location mismatch:
ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
nan, nan],
[-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
nan, nan],
[-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
-3.443364e-04, nan, nan, nan,
nan, nan, nan, nan,...
DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
-3.047717e-04, nan, nan, nan,
nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

nan location mismatch:
ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
nan, nan],
[-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
nan, nan],
[-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
Eager and cuda graph mode lps with use_inductor disabled (prompt lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
Eager and cuda graph mode lps with use_inductor disabled (generation lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
```

**What this script tests:**

The script compares both prompt and generation log probabilities under the following setups (a minimal sketch of the comparison follows the list):

1. **Eager vs CUDA Graph Mode**: Compares log probabilities between eager execution (ground truth) and CUDA graph compilation mode
- **⚠️ Commonly fails**: This comparison often shows discrepancies due to compilation optimizations
2. **Torch Inductor Precision**: Tests with `TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` environment variable
- **⚠️ May help**: This flag may help but typically doesn't resolve all the numerical differences
3. **Inductor Disabled**: Verifies that disabling Torch Inductor compilation (`use_inductor=False`) maintains output consistency
- **✅ Usually works well**: This configuration often produces results very close to eager mode
- **Note**: `use_inductor=False` disables Inductor compilation but keeps CUDA graph capture active for compatible operations
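
For reference, here is a minimal sketch of the kind of comparison the script performs (not the script itself). It assumes a vLLM version whose `LLM` constructor accepts a `compilation_config` dict, uses an illustrative prompt and tolerances, and in practice each engine may need to run in its own process to avoid GPU memory contention.

```python
import numpy as np
from vllm import LLM, SamplingParams

MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
PROMPTS = ["The capital of France is"]
# Greedy decoding keeps both configurations on the same sampled tokens in most cases.
PARAMS = SamplingParams(temperature=0.0, max_tokens=8, logprobs=1)


def generation_logprobs(**engine_kwargs) -> np.ndarray:
    # Build a fresh engine for the given configuration and return the logprob
    # of every sampled token for the first prompt.
    llm = LLM(model=MODEL, **engine_kwargs)
    out = llm.generate(PROMPTS, PARAMS)[0].outputs[0]
    return np.array(
        [lps[tok].logprob for tok, lps in zip(out.token_ids, out.logprobs)]
    )


eager = generation_logprobs(enforce_eager=True)  # ground truth
graph = generation_logprobs(
    enforce_eager=False,
    compilation_config={"use_inductor": False},  # CUDA graphs, Inductor disabled
)
np.testing.assert_allclose(graph, eager, atol=1e-3, rtol=1e-3)
```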

**Performance vs Accuracy Trade-offs:**

The different compilation modes offer distinct trade-offs between accuracy and performance:

- **Eager Mode** (`enforce_eager=True`): Highest accuracy (ground truth) but slowest execution
- **CUDA Graph Mode with Inductor Disabled** (`enforce_eager=False` and `compilation_config={"use_inductor": False}`): Near-eager accuracy with significant speedup from CUDA graph optimization
- **CUDA Graph Mode with Inductor Enabled** (`enforce_eager=False` and `compilation_config={"use_inductor": True}`): Potentially fastest execution with custom Triton kernels (since Triton is the current backend of Inductor), but may introduce numerical differences. For accuracy improvement, try the torch inductor precision flag: `export TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1`

**Note**: Performance characteristics vary by model. For example, `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` shows similar speed performance between `use_inductor=True` and `use_inductor=False`, making the accuracy-preserving option preferable.
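
As a hedged example of applying the accuracy-preserving configuration in a NeMo RL run (the entry-point script and config are illustrative; the `++` override key comes from the config comments in this change):

```sh
# Keep CUDA graphs (enforce_eager: False in the config) but disable Inductor:
uv run python examples/run_grpo_math.py \
  ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False

# If Inductor stays enabled, this flag may reduce (but not eliminate) precision drift:
export TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1
```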

**Why this matters:**

- **Debugging**: Helps identify which compilation settings cause numerical differences
- **Configuration**: Shows which settings work best for your model
- **Understanding**: Reveals how compilation affects model outputs

**When to use:**

- **Model integration** - understand numerical behavior across vLLM configurations
- **Debugging** - investigate differences between development and production
- **Research** - study compilation strategy impacts on precision

**Interpreting results:**

- **Eager vs CUDA Graph failures are normal** - don't panic if this fails
- **Focus on patterns** - some models are more sensitive than others
- **Use as guidance** - helps choose reliable compilation settings
- **Balance precision vs performance** - choose what works for your use case
4 changes: 4 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -176,10 +176,14 @@ policy:
pipeline_parallel_size: 1
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}
# when enforce_eager is False, you can optionally set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy;
# with that flag, vLLM uses its custom CUDA kernels instead of the Triton kernels generated by torch.compile
# for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
enforce_eager: False
use_deep_gemm: False
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
vllm_kwargs: {}
colocated:
# true: generation shares training GPUs
# false: uses dedicated generation resources
8 changes: 7 additions & 1 deletion examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
@@ -101,7 +101,13 @@ policy:
pipeline_parallel_size: 1
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}
enforce_eager: True
enforce_eager: False
vllm_kwargs:
compilation_config:
# when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy;
# with this flag, vLLM uses its custom CUDA kernels instead of the Triton kernels generated by torch.compile
# for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
use_inductor: False
colocated:
# true: generation shares training GPUs
# false: uses dedicated generation resources
7 changes: 7 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -36,9 +36,16 @@
) # pragma: no cover
class VllmAsyncGenerationWorker(BaseVllmGenerationWorker):
def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

# (TODO: zhiyul) Remove this workaround after upgrading vLLM where the compilation_config passing issue is resolved.
if llm_kwargs.get("compilation_config", None):
llm_kwargs["compilation_config"] = CompilationConfig(
**llm_kwargs["compilation_config"]
)

self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs))

async def post_init_async(self):
1 change: 1 addition & 0 deletions pyrefly.toml
@@ -119,6 +119,7 @@ project-includes = [
"nemo_rl/utils/venvs.py",
"tools/model_diagnostics/1.max_model_len_respected.py",
"tools/model_diagnostics/2.long_generation_decode_vs_prefill.py",
"tools/model_diagnostics/4.vllm_precision_compilation_test.py",
]

# Disable TypedDict mutation errors since TypedDict objects are regular dicts at runtime
17 changes: 17 additions & 0 deletions tests/unit/test_recipes_and_test_suites.py
@@ -40,6 +40,11 @@
"vlm_grpo": "examples/configs/vlm_grpo_3B.yaml",
}

# Configuration keys that are allowed to be added to base configs during testing
# These keys may exist in recipe configs but not in base configs, so we need to
# manually add them to avoid merge conflicts during config validation
ALLOWED_ADDITIONAL_CONFIG_KEYS = ["policy.generation.vllm_kwargs"]


@pytest.fixture
def nightly_test_suite():
@@ -298,6 +303,18 @@ def test_all_recipes_can_merge_configs_with_base_config(
recipe_yaml_path = os.path.join(recipes_dir, recipe_yaml)
recipe_config = load_config(recipe_yaml_path)
OmegaConf.set_struct(recipe_config, True)

# Manually add the keys listed in ALLOWED_ADDITIONAL_CONFIG_KEYS to the base config
# This prevents merge conflicts when recipe configs contain keys not present in base configs
for key in ALLOWED_ADDITIONAL_CONFIG_KEYS:
if OmegaConf.select(recipe_config, key):
OmegaConf.update(
base_config,
key,
OmegaConf.select(recipe_config, key),
force_add=True,
)

# This will raise an error if the config can't be merged
print(f"Merging {recipe_yaml} with {base_yaml}")
merged_config = OmegaConf.merge(base_config, recipe_config)