Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {
"nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
# Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
# This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP,
"nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
"nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
Expand Down
3 changes: 3 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class PY_EXECUTABLES:
# Use NeMo-RL direct dependencies and vllm.
VLLM = f"uv run --locked --extra vllm --directory {git_root}"

# Use NeMo-RL direct dependencies and fsdp.
FSDP = f"uv run --locked --extra fsdp --directory {git_root}"

# Use NeMo-RL direct dependencies and nemo-automodel.
AUTOMODEL = f"uv run --locked --extra automodel --directory {git_root}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,89 +107,6 @@ def patched_run_workers(self, *args, **kwargs):
fp8_patches_applied = True


def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]:
    """Build the vLLM parameter names for one layer's Q/K/V FP8 scales.

    Centralizes the naming convention so every producer of FP8 scales emits
    names that match vLLM's internal parameter structure.

    Args:
        layer_idx: 0-based transformer layer index.

    Returns:
        Mapping from scale type ('q_scale', 'k_scale', 'v_scale') to the
        corresponding vLLM parameter name.

    Note:
        Only q_scale carries an extra '.attn.' path segment; this mirrors
        vLLM's remapping logic in
        vllm.model_executor.model_loader.weight_utils.maybe_remap_kv_scale_name

    Example:
        >>> get_vllm_qkv_scale_names(0)
        {
            'q_scale': 'model.layers.0.self_attn.attn.q_scale',
            'k_scale': 'model.layers.0.self_attn.k_scale',
            'v_scale': 'model.layers.0.self_attn.v_scale'
        }
    """
    prefix = f"model.layers.{layer_idx}.self_attn"
    names = {"q_scale": f"{prefix}.attn.q_scale"}
    for kv_name in ("k_scale", "v_scale"):
        names[kv_name] = f"{prefix}.{kv_name}"
    return names


def convert_calibration_to_vllm_format(
    calibration_results: dict[str, dict[str, float]],
) -> dict[str, float]:
    """Flatten NeMo-RL calibration output into vLLM's parameter-name format.

    Currently only used by the megatron policy worker; once the DTensor path
    supports FP8 KV cache it can reuse this helper as well.

    Takes the per-layer calibration dict (keyed "layer_N") and produces the
    flat {vllm_param_name: scale} mapping vLLM expects when loading weights.

    Args:
        calibration_results: Keys of the form "layer_0", "layer_1", ...; each
            value holds float scales under "q_scale", "k_scale", "v_scale".

    Returns:
        Flat dict from vLLM parameter name (per get_vllm_qkv_scale_names)
        to the corresponding scale value.

    Example:
        >>> calib = {
        ...     "layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0},
        ...     "layer_1": {"q_scale": 1.5, "k_scale": 2.5, "v_scale": 3.5}
        ... }
        >>> convert_calibration_to_vllm_format(calib)
        {
            'model.layers.0.self_attn.attn.q_scale': 1.0,
            'model.layers.0.self_attn.k_scale': 2.0,
            'model.layers.0.self_attn.v_scale': 3.0,
            'model.layers.1.self_attn.attn.q_scale': 1.5,
            'model.layers.1.self_attn.k_scale': 2.5,
            'model.layers.1.self_attn.v_scale': 3.5
        }
    """
    flattened: dict[str, float] = {}
    for layer_key, layer_scales in calibration_results.items():
        # Keys arrive as "layer_<idx>"; the suffix is the layer index.
        layer_idx = int(layer_key.split("_")[1])
        names = get_vllm_qkv_scale_names(layer_idx)
        for scale_type in ("q_scale", "k_scale", "v_scale"):
            flattened[names[scale_type]] = layer_scales[scale_type]
    return flattened


def apply_fp8_patches(self, fp8_config):
global global_fp8_config, fp8_patches_applied
assert not fp8_patches_applied
Expand Down
96 changes: 96 additions & 0 deletions nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]:
    """Return the vLLM-side parameter names for a layer's Q/K/V FP8 scales.

    Keeps the naming convention in one place; the returned names must match
    vLLM's internal parameter structure exactly.

    Args:
        layer_idx: 0-based transformer layer index.

    Returns:
        Dict with keys 'q_scale', 'k_scale', 'v_scale' mapped to the
        corresponding vLLM parameter names.

    Note:
        q_scale alone has an extra '.attn.' path component compared with
        k_scale/v_scale. This matches vLLM's remapping in:
        vllm.model_executor.model_loader.weight_utils.maybe_remap_kv_scale_name

    Example:
        >>> get_vllm_qkv_scale_names(0)
        {
            'q_scale': 'model.layers.0.self_attn.attn.q_scale',
            'k_scale': 'model.layers.0.self_attn.k_scale',
            'v_scale': 'model.layers.0.self_attn.v_scale'
        }
    """
    base = f"model.layers.{layer_idx}.self_attn"
    return {
        scale: f"{base}{suffix}"
        for scale, suffix in (
            ("q_scale", ".attn.q_scale"),
            ("k_scale", ".k_scale"),
            ("v_scale", ".v_scale"),
        )
    }


def convert_calibration_to_vllm_format(
    calibration_results: dict[str, dict[str, float]],
) -> dict[str, float]:
    """Convert NeMo-RL calibration results into vLLM's flat parameter dict.

    Currently only used by the megatron policy worker. Once FP8 KV cache is
    supported on the DTensor path, that path can reuse this function too.

    The calibration output groups scales per layer under "layer_N" keys; vLLM
    instead wants a single flat mapping from fully-qualified parameter name to
    scale value, which is what this produces.

    Args:
        calibration_results: Dict keyed "layer_0", "layer_1", ... where each
            value maps "q_scale"/"k_scale"/"v_scale" to a float.

    Returns:
        Flat dict of vLLM parameter names (see get_vllm_qkv_scale_names) to
        their scale values.

    Example:
        >>> calib = {
        ...     "layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0},
        ...     "layer_1": {"q_scale": 1.5, "k_scale": 2.5, "v_scale": 3.5}
        ... }
        >>> convert_calibration_to_vllm_format(calib)
        {
            'model.layers.0.self_attn.attn.q_scale': 1.0,
            'model.layers.0.self_attn.k_scale': 2.0,
            'model.layers.0.self_attn.v_scale': 3.0,
            'model.layers.1.self_attn.attn.q_scale': 1.5,
            'model.layers.1.self_attn.k_scale': 2.5,
            'model.layers.1.self_attn.v_scale': 3.5
        }
    """
    result: dict[str, float] = {}
    for layer_key, scales in calibration_results.items():
        # "layer_N" -> N; the calibration contract guarantees this key shape.
        param_names = get_vllm_qkv_scale_names(int(layer_key.split("_")[1]))
        result.update(
            (param_names[kind], scales[kind])
            for kind in ("q_scale", "k_scale", "v_scale")
        )
    return result
4 changes: 2 additions & 2 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def update_weights_via_ipc_zmq(self) -> bool:
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
)
# Load weights into the model
from nemo_rl.models.generation import fp8
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
Expand Down Expand Up @@ -230,7 +230,7 @@ def _load_model_weights(weights, model_runner):
Returns:
None
"""
from nemo_rl.models.generation import fp8
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _patch_vllm_vit_flash_attn_backend():
# Call init_fp8 when precision is fp8
# (kv_cache_dtype can be fp8/fp8_e4m3 or auto, validated in init_fp8)
if self.cfg["vllm_cfg"]["precision"] == "fp8":
from nemo_rl.models.generation.fp8 import init_fp8
from nemo_rl.models.generation.vllm.quantization.fp8 import init_fp8

fp8_kwargs = init_fp8(
self.cfg["vllm_cfg"], self.model_name, model_parallel_size
Expand Down
11 changes: 7 additions & 4 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,6 @@
from_parallel_logits_to_logprobs_packed_sequences,
)
from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.models.generation.fp8 import (
convert_calibration_to_vllm_format,
get_vllm_qkv_scale_names,
)
from nemo_rl.models.generation.interfaces import (
GenerationDatumSpec,
GenerationOutputSpec,
Expand Down Expand Up @@ -2139,6 +2135,10 @@ def _iter_params_with_optional_kv_scales(
This helper is used by both IPC-based streaming and collective broadcast
so that the logic for adding KV scales stays consistent in one place.
"""
from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
get_vllm_qkv_scale_names,
)

base_iter = self.megatron_bridge.export_hf_weights(
[self.model],
show_progress=False,
Expand Down Expand Up @@ -2544,6 +2544,9 @@ def calibrate_qkv_fp8_scales(
{ "format": "fp8", "percentile": float, "margin": float,
"layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } }
"""
from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
convert_calibration_to_vllm_format,
)

# Allow overriding FP8 max for Q, K, V via environment variables for ease of testing.
# Defaults align with FP8 e4m3 max magnitude.
Expand Down
21 changes: 11 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,26 @@ dependencies = [
]

[project.optional-dependencies]
# Used for the DTensor/FSDP path; once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved, this should become the "BASE" PY_EXECUTABLE
fsdp = [
"flash-attn==2.8.1",
"mamba-ssm",
"causal-conv1d",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"vllm==0.11.2",
]
automodel = [
"nemo-automodel",
# Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular)
# https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
# https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
"vllm==0.11.2", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
"flash-attn==2.8.1",
"mamba-ssm",
"causal-conv1d",
"nv-grouped-gemm",
"transformer-engine[pytorch]==2.8.0",
"deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"vllm==0.11.2",
]
vllm = [
"cuda-python",
Expand All @@ -75,12 +82,6 @@ vllm = [
"deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480",
"vllm==0.11.2",
"num2words>=0.5.14",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"flash-attn==2.8.1",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"mamba-ssm",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"causal-conv1d",
]
mcore = [
# also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network)
Expand All @@ -95,12 +96,12 @@ mcore = [
"transformer-engine[pytorch]==2.8.0",
"megatron-core",
"megatron-bridge",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"vllm==0.11.2",
# Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular)
# https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
# https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
"flash-attn==2.8.1",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"vllm==0.11.2",
]
nemo_gym = ["nemo_gym"]

Expand Down
1 change: 1 addition & 0 deletions pyrefly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ project-includes = [
"nemo_rl/models/generation/interfaces.py",
"nemo_rl/models/generation/vllm/__init__.py",
"nemo_rl/models/generation/vllm/config.py",
"nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py",
"nemo_rl/models/generation/vllm/utils.py",
"nemo_rl/models/generation/vllm/vllm_backend.py",
"nemo_rl/models/huggingface/__init__.py",
Expand Down
18 changes: 11 additions & 7 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading