78 changes: 29 additions & 49 deletions nemo_rl/models/generation/fp8.py
@@ -80,11 +80,11 @@ def my_run_engine_core(*args, **kwargs):
 
 def monkey_patch_vllm_ray_executor(fp8_config):
     if fp8_config.model_parallel_size > 1:
-        # we patch vllm's _run_workers so that before vllm initializes the model on each rank, we execute
+        # we patch vllm's collective_rpc so that before vllm initializes the model on each rank, we execute
         # a ray remote that patches each worker with the required fp8 vllm patches
         from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor
 
-        original_run_workers = RayDistributedExecutor._run_workers
+        original_run_workers = RayDistributedExecutor.collective_rpc
 
         def patched_run_workers(self, *args, **kwargs):
            global fp8_patches_applied
@@ -98,7 +98,7 @@ def patched_run_workers(self, *args, **kwargs):
 
             return original_run_workers(self, *args, **kwargs)
 
-        RayDistributedExecutor._run_workers = patched_run_workers
+        RayDistributedExecutor.collective_rpc = patched_run_workers
     else:
         # for single gpu there is no ray, so just call the patches
         apply_fp8_patches(None, fp8_config)
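The hunk above only swaps the patch target from the private `_run_workers` to `collective_rpc`; the underlying trick is unchanged: a wrap-and-replace monkey patch that runs one-time setup before delegating to the saved original method. A minimal, self-contained sketch of that pattern follows (all names here are hypothetical stand-ins, not vLLM's API):

```python
# Wrap-and-replace sketch: run one-time setup before the first dispatch,
# then always defer to the saved original method. Hypothetical names only.
_patches_applied = False


class Executor:
    def collective_rpc(self, method: str, *args, **kwargs):
        print(f"dispatching {method!r} to workers")


def _apply_patches_once() -> None:
    global _patches_applied
    if not _patches_applied:
        print("applying fp8-style patches before the first RPC")
        _patches_applied = True


_original_rpc = Executor.collective_rpc


def _patched_rpc(self, *args, **kwargs):
    _apply_patches_once()                        # setup runs before any worker call
    return _original_rpc(self, *args, **kwargs)  # then defer to the original


Executor.collective_rpc = _patched_rpc

ex = Executor()
ex.collective_rpc("load_model")  # patches applied exactly once, then dispatched
ex.collective_rpc("generate")    # subsequent calls go straight through
```

Saving the original method as a module-level reference before rebinding is what lets the wrapper stay a thin pass-through.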
@@ -225,8 +225,6 @@ def apply_fp8_patches(self, fp8_config):
         patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor)
         fp8_state.vllm_patches.extend([patcher2, patcher3, patcher4])
 
-    # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8)
-    if global_fp8_config.kv_cache_dtype == "fp8":
     # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates
     func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading"
     patcher5 = patch(func5_path, process_weights_after_loading_kv)
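For context, the patcher objects collected in `fp8_state.vllm_patches` appear to follow the standard `unittest.mock.patch` start/stop protocol: constructing a patcher does nothing until `start()`, and keeping the patcher around lets the original be restored later. A small sketch under that assumption (`math.sqrt` is just a stand-in target):

```python
# Long-lived patcher sketch: swap in a replacement with start(), keep the
# patcher in a list, and restore the original with stop() during teardown.
import math
from unittest.mock import patch


def _fake_sqrt(x: float) -> float:
    return -1.0  # stand-in replacement implementation


active_patches = []

patcher = patch("math.sqrt", _fake_sqrt)
patcher.start()                  # the target now resolves to _fake_sqrt
active_patches.append(patcher)

assert math.sqrt(9.0) == -1.0    # replacement is live

for p in active_patches:         # teardown: undo every patch
    p.stop()
assert math.sqrt(9.0) == 3.0     # original restored
```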
@@ -593,7 +591,7 @@ def process_weights_after_loading(self, layer) -> None:
     layer.weight.data = weight.data
     if hasattr(layer, "weight_scale"):
         # Not the first time to call this function, just need to update the data
-        layer.weight_scale.data = weight_scale.data
+        layer.weight_scale.copy_(weight_scale.data)
     else:
         # The first time to call this function, create a new parameter and update the tp status
         layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
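The switch from rebinding `weight_scale.data` to `weight_scale.copy_(...)` writes the new values into the parameter's existing storage rather than swapping the backing tensor out from under it. The broader point, spelled out in the docstring below, is that the parameter object itself must survive, because attributes attached to it (such as the `weight_loader` needed for refit) are lost when a fresh `torch.nn.Parameter` is created. A minimal demonstration (the `weight_loader` string is a stand-in):

```python
# In-place copy_ preserves the Parameter object and anything attached to it;
# constructing a new Parameter does not.
import torch

p = torch.nn.Parameter(torch.zeros(2), requires_grad=False)
p.weight_loader = "custom refit loader"   # attribute we must not lose

p.copy_(torch.ones(2))                    # same object, updated values
assert p.weight_loader == "custom refit loader"

q = torch.nn.Parameter(torch.ones(2), requires_grad=False)
assert not hasattr(q, "weight_loader")    # fresh object: attribute is gone
```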
@@ -609,68 +607,50 @@ def process_weights_after_loading_moe(self, layer) -> None:
     new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit.
     """
     # Lazy import to avoid importing triton too early.
-    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-        is_rocm_aiter_moe_enabled,
-    )
+    from vllm._aiter_ops import rocm_aiter_ops
     from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
         swap_w13_to_w31,
     )
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         expert_weight_is_col_major,
-        requant_weight_ue8m0_inplace,
+        deepgemm_post_process_fp8_weight_block,
     )
     from vllm.utils.deep_gemm import (
         get_col_major_tma_aligned_tensor,
         is_deep_gemm_e8m0_used,
     )
 
-    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+    self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
 
     assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
     assert self.quant_config.activation_scheme == "dynamic"
 
     if self.flashinfer_moe_backend is not None:
-        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = swap_w13_to_w31(
-            layer.w13_weight_scale_inv.data
-        )
+        w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+        w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+    else:
+        w13_weight = layer.w13_weight.data
+        w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+    w2_weight = layer.w2_weight.data
+    w2_weight_scale_inv = layer.w2_weight_scale_inv.data
 
-    # DeepGemm scales need to be transposed and aligned. We try to do
-    # it ahead of time for performance reasons.
-    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
-        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
-
-    if is_deep_gemm_e8m0_used():
-        assert layer.weight_block_size is not None
-        # Re-quantise the expert weights so their scales are UE8M0.
-        block_sz = tuple(layer.weight_block_size)
-        requant_weight_ue8m0_inplace(
-            layer.w13_weight.data,
-            layer.w13_weight_scale_inv.data,
-            block_sz,
+    if self.allow_deep_gemm:
+        w13_weight, w13_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w13_weight,
+            ws=w13_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
         )
-        requant_weight_ue8m0_inplace(
-            layer.w2_weight.data,
-            layer.w2_weight_scale_inv.data,
-            block_sz,
+        w2_weight, w2_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w2_weight,
+            ws=w2_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
         )
 
+        # Ensure column-major TMA alignment expected by DeepGEMM.
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w13_weight_scale_inv
+            )
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w2_weight_scale_inv
+            )
+    layer.w13_weight.copy_(w13_weight)
+    layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv)
+    layer.w2_weight.copy_(w2_weight)
+    layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv)
 
 
 def process_weights_after_loading_kv(self, layer) -> None:
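The rewritten MoE hunk above stages every transform on local tensors (`w13_weight`, `w2_weight_scale_inv`, ...) so that the flashinfer and non-flashinfer branches converge on a single unconditional `copy_` write-back into the existing parameters at the end, consistent with the docstring's refit constraint. A schematic of that stage-then-copy-back flow, with placeholder shapes and a placeholder transform rather than the real post-processing:

```python
# Stage-then-copy-back sketch: branch on configuration while computing into
# locals, then write once into the long-lived Parameter objects.
import torch


class FakeLayer:
    def __init__(self) -> None:
        self.w2_weight = torch.nn.Parameter(torch.randn(4, 8), requires_grad=False)


def post_process(layer: FakeLayer, use_alt_layout: bool) -> None:
    w2 = layer.w2_weight.data     # stage a local handle
    if use_alt_layout:
        w2 = w2.flip(-1)          # placeholder for a layout transform
    w2 = w2 * 2.0                 # placeholder for requant/post-processing
    layer.w2_weight.copy_(w2)     # single in-place write-back, same Parameter


layer = FakeLayer()
post_process(layer, use_alt_layout=True)
```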