diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 4f42b79e1d..9acf6a88ef 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -80,11 +80,11 @@ def my_run_engine_core(*args, **kwargs): def monkey_patch_vllm_ray_executor(fp8_config): if fp8_config.model_parallel_size > 1: - # we patch vllm's _run_workers so that before vllm initalizes the model on each rank, we execute + # we patch vllm's collective_rpc so that before vllm initalizes the model on each rank, we execute # a ray remote that patches each worker with the required fp8 vllm patches from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor - original_run_workers = RayDistributedExecutor._run_workers + original_run_workers = RayDistributedExecutor.collective_rpc def patched_run_workers(self, *args, **kwargs): global fp8_patches_applied @@ -98,7 +98,7 @@ def patched_run_workers(self, *args, **kwargs): return original_run_workers(self, *args, **kwargs) - RayDistributedExecutor._run_workers = patched_run_workers + RayDistributedExecutor.collective_rpc = patched_run_workers else: # for single gpu there is no ray, so just call the patches apply_fp8_patches(None, fp8_config) @@ -225,8 +225,6 @@ def apply_fp8_patches(self, fp8_config): patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor) fp8_state.vllm_patches.append(patcher2, patcher3, patcher4) - # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8) - if global_fp8_config.kv_cache_dtype == "fp8": # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" patcher5 = patch(func5_path, process_weights_after_loading_kv) @@ -593,7 +591,7 @@ def process_weights_after_loading(self, layer) -> None: layer.weight.data = weight.data if hasattr(layer, "weight_scale"): # Not the first time to call this function, just need to update the data - layer.weight_scale.data = weight_scale.data + layer.weight_scale.copy_(weight_scale.data) else: # The first time to call this function, create a new parameter and update the tp status layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False) @@ -609,68 +607,50 @@ def process_weights_after_loading_moe(self, layer) -> None: new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. """ # Lazy import to avoid importing triton too early. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) + from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - expert_weight_is_col_major, - requant_weight_ue8m0_inplace, + deepgemm_post_process_fp8_weight_block, ) from vllm.utils.deep_gemm import ( - get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.activation_scheme == "dynamic" if self.flashinfer_moe_backend is not None: - layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - layer.w13_weight_scale_inv.data = swap_w13_to_w31( - layer.w13_weight_scale_inv.data - ) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data) + else: + w13_weight = layer.w13_weight.data + w13_weight_scale_inv = layer.w13_weight_scale_inv.data + w2_weight = layer.w2_weight.data + w2_weight_scale_inv = layer.w2_weight_scale_inv.data # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - if expert_weight_is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv - ) - if expert_weight_is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv - ) - - if is_deep_gemm_e8m0_used(): - assert layer.weight_block_size is not None - # Re-quantise the expert weights so their scales are UE8M0. - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.w13_weight.data, - layer.w13_weight_scale_inv.data, - block_sz, + if self.allow_deep_gemm: + w13_weight, w13_weight_scale_inv = deepgemm_post_process_fp8_weight_block( + wq=w13_weight, + ws=w13_weight_scale_inv, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), ) - requant_weight_ue8m0_inplace( - layer.w2_weight.data, - layer.w2_weight_scale_inv.data, - block_sz, + w2_weight, w2_weight_scale_inv = deepgemm_post_process_fp8_weight_block( + wq=w2_weight, + ws=w2_weight_scale_inv, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), ) - - # Ensure column-major TMA alignment expected by DeepGEMM. - if expert_weight_is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv - ) - if expert_weight_is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv - ) + layer.w13_weight.copy_(w13_weight) + layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv) + layer.w2_weight.copy_(w2_weight) + layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv) def process_weights_after_loading_kv(self, layer) -> None: