62 changes: 58 additions & 4 deletions vllm_gaudi/attention/oot_mla.py
@@ -157,10 +157,64 @@ def forward_impl(

     # during each graph execution
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        MLAAttention.process_weights_after_loading(self, act_dtype)
-        #super(MLAAttention, self).process_weights_after_loading(act_dtype)
-        self.W_UV: torch.Tensor = self.W_UV.contiguous()
-        self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous()
+        # HPU-specific: when VLLM_HPU_FORCE_CHANNEL_FP8=True (default), block-quantized
+        # FP8 weights (e.g. kv_b_proj in DeepSeek-R1) are converted to channel-wise FP8.
+        # After this conversion, weight_scale_inv becomes 1D [N_out] (per-channel) but
+        # weight_block_size is not cleared. The upstream MLAAttention.process_weights_after_loading
+        # then calls scaled_dequantize with group_shape=weight_block_size, which fails
+        # because a 1D scale is incompatible with a 2D block group_shape.
+        # We handle this by directly dequantizing kv_b_proj for the HPU path.
+        kv_b_proj = self.kv_b_proj
+        weight = kv_b_proj.weight
+        weight_scale_inv = getattr(kv_b_proj, 'weight_scale_inv', None)
+
+        if weight.dtype == torch.float8_e4m3fn and weight_scale_inv is not None:
+            if weight_scale_inv.dim() == 1:
+                # Channel-wise FP8 (produced by VLLM_HPU_FORCE_CHANNEL_FP8=True):
+                # one scale per output channel; dequant via simple broadcast multiply.
+                ws = weight_scale_inv.view(-1, 1).to(act_dtype)  # [N_out, 1]
+                kv_b_proj_weight = (weight.to(act_dtype) * ws).T
+            else:
+                # Block FP8 (force_channel_fp8=False): use HPU block dequant.
+                from vllm_gaudi.extension.ops import dequant_block_fp8_weight_naive
+                orig_M = kv_b_proj.orig_M.item() if hasattr(kv_b_proj, 'orig_M') else None
+                orig_N = kv_b_proj.orig_N.item() if hasattr(kv_b_proj, 'orig_N') else None
+                kv_b_proj_weight = dequant_block_fp8_weight_naive(
+                    weight,
+                    weight_scale_inv,
+                    kv_b_proj.weight_block_size,
+                    dtype=act_dtype,
+                    original_M=orig_M,
+                    original_N=orig_N,
+                    do_unpad=(orig_M is not None),
+                ).T
+
+            assert kv_b_proj_weight.shape == (
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            ), (f"{kv_b_proj_weight.shape=}, "
+                f"{self.kv_lora_rank=}, {self.num_heads=}, "
+                f"{self.qk_nope_head_dim=}, {self.v_head_dim=}")
+            kv_b_proj_weight = kv_b_proj_weight.view(
+                self.kv_lora_rank,
+                self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            )
+            W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+
+            from vllm.model_executor.layers.attention.attention import (set_default_quant_scales,
+                                                                        should_load_quant_weights)
+            quant_method = (self.quant_config.get_quant_method(self, prefix=self.layer_name)
+                            if self.quant_config else None)
+            if not should_load_quant_weights(quant_method):
+                set_default_quant_scales(self, register_buffer=False)
+        else:
+            # Non-FP8 kv_b_proj: use upstream logic as before.
+            MLAAttention.process_weights_after_loading(self, act_dtype)
+            self.W_UV = self.W_UV.contiguous()
+            self.W_UK_T = self.W_UK_T.contiguous()
 
     # NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
     # so we override and always return a new tensor
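
Why the 1D scale breaks the upstream path: channel-wise dequantization is a broadcast multiply with a per-row scale, whereas scaled_dequantize with a block group_shape expects a 2D grid of per-tile scales. A minimal sketch of the two layouts (shapes here are illustrative, not taken from the PR):

import torch

N_out, N_in = 512, 256
weight = torch.randn(N_out, N_in).to(torch.float8_e4m3fn)

# Channel-wise FP8 (VLLM_HPU_FORCE_CHANNEL_FP8=True): 1D scale, one entry per
# output channel, so dequant is a plain broadcast multiply.
scale_channel = torch.rand(N_out)
deq = weight.to(torch.bfloat16) * scale_channel.view(-1, 1).to(torch.bfloat16)
assert deq.shape == (N_out, N_in)

# Block FP8: the scale is a 2D grid, one entry per (block_n x block_k) tile.
# scaled_dequantize(..., group_shape=weight_block_size) expects this layout;
# handing it the 1D channel scale above is exactly the mismatch that failed.
block_n, block_k = 128, 128
scale_block = torch.rand(N_out // block_n, N_in // block_k)  # [4, 2], not [N_out]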
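The tail of the FP8 branch re-derives W_UV and W_UK_T from the dequantized kv_b_proj weight. A shape walk-through of that view/split/transpose/permute sequence, using assumed DeepSeek-R1-like dimensions (the concrete values below are illustrative):

import torch

# Assumed values; only the shape relationships matter here.
kv_lora_rank, num_heads = 512, 128
qk_nope_head_dim, v_head_dim = 128, 128

w = torch.randn(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
w = w.view(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
W_UK, W_UV = w.split([qk_nope_head_dim, v_head_dim], dim=-1)

# The layouts the patch stores:
assert W_UV.transpose(0, 1).shape == (num_heads, kv_lora_rank, v_head_dim)
assert W_UK.permute(1, 2, 0).shape == (num_heads, qk_nope_head_dim, kv_lora_rank)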
4 changes: 4 additions & 0 deletions vllm_gaudi/extension/ops.py
@@ -967,6 +967,10 @@ def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
             weight_scale_inv = weight_scale_inv.squeeze(-1)
         layer.weight.data.copy_(weight)
         layer.weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        # Scale is now per-channel, not per-block; clear stale block size to
+        # prevent downstream code (e.g. scaled_dequantize) from using it as
+        # a group_shape that is incompatible with the 1D channel-wise scale.
+        layer.weight_block_size = None
         htorch.core.mark_step()
         return layer
     else:
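
For reference, the block-to-channel conversion that leaves weight_scale_inv 1D can be sketched roughly as below; block_to_channel_fp8 is a hypothetical stand-in for illustration, not the actual logic inside fp8_block_linear_postprocess_weights:

import torch

def block_to_channel_fp8(weight_fp8, scale_inv, block_size, dtype=torch.bfloat16):
    # Hypothetical sketch: dequantize block-wise, then requantize per output
    # channel (assumes weight dims are exact multiples of block_size).
    block_n, block_k = block_size
    scales = scale_inv.repeat_interleave(block_n, 0).repeat_interleave(block_k, 1)
    deq = weight_fp8.to(dtype) * scales.to(dtype)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    ch_scale = (deq.abs().amax(dim=1, keepdim=True) / fp8_max).clamp(min=1e-12)
    w_fp8 = (deq / ch_scale).to(torch.float8_e4m3fn)
    # ch_scale squeezed to 1D [N_out] is per-channel and carries no block
    # structure, which is why the patch clears layer.weight_block_size after
    # the conversion.
    return w_fp8, ch_scale.squeeze(-1)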