diff --git a/vllm_gaudi/attention/oot_mla.py b/vllm_gaudi/attention/oot_mla.py
index 403618211..5be1f1a69 100644
--- a/vllm_gaudi/attention/oot_mla.py
+++ b/vllm_gaudi/attention/oot_mla.py
@@ -157,10 +157,64 @@ def forward_impl(
     # during each graph execution
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        MLAAttention.process_weights_after_loading(self, act_dtype)
-        #super(MLAAttention, self).process_weights_after_loading(act_dtype)
-        self.W_UV: torch.Tensor = self.W_UV.contiguous()
-        self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous()
+        # HPU-specific: when VLLM_HPU_FORCE_CHANNEL_FP8=True (default), block-quantized
+        # FP8 weights (e.g. kv_b_proj in DeepSeek-R1) are converted to channel-wise FP8.
+        # After this conversion, weight_scale_inv becomes 1D [N_out] (per-channel) but
+        # weight_block_size is not cleared. The upstream MLAAttention.process_weights_after_loading
+        # then calls scaled_dequantize with group_shape=weight_block_size, which fails
+        # because a 1D scale is incompatible with a 2D block group_shape.
+        # We handle this by directly dequantizing kv_b_proj for the HPU path.
+        kv_b_proj = self.kv_b_proj
+        weight = kv_b_proj.weight
+        weight_scale_inv = getattr(kv_b_proj, 'weight_scale_inv', None)
+
+        if weight.dtype == torch.float8_e4m3fn and weight_scale_inv is not None:
+            if weight_scale_inv.dim() == 1:
+                # Channel-wise FP8 (produced by VLLM_HPU_FORCE_CHANNEL_FP8=True):
+                # one scale per output channel; dequant via simple broadcast multiply.
+                ws = weight_scale_inv.view(-1, 1).to(act_dtype)  # [N_out, 1]
+                kv_b_proj_weight = (weight.to(act_dtype) * ws).T
+            else:
+                # Block FP8 (force_channel_fp8=False): use HPU block dequant.
+                from vllm_gaudi.extension.ops import dequant_block_fp8_weight_naive
+                orig_M = kv_b_proj.orig_M.item() if hasattr(kv_b_proj, 'orig_M') else None
+                orig_N = kv_b_proj.orig_N.item() if hasattr(kv_b_proj, 'orig_N') else None
+                kv_b_proj_weight = dequant_block_fp8_weight_naive(
+                    weight,
+                    weight_scale_inv,
+                    kv_b_proj.weight_block_size,
+                    dtype=act_dtype,
+                    original_M=orig_M,
+                    original_N=orig_N,
+                    do_unpad=(orig_M is not None),
+                ).T
+
+            assert kv_b_proj_weight.shape == (
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            ), (f"{kv_b_proj_weight.shape=}, "
+                f"{self.kv_lora_rank=}, {self.num_heads=}, "
+                f"{self.qk_nope_head_dim=}, {self.v_head_dim=}")
+            kv_b_proj_weight = kv_b_proj_weight.view(
+                self.kv_lora_rank,
+                self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            )
+            W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            self.W_UV = W_UV.transpose(0, 1).contiguous()
+            self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+
+            from vllm.model_executor.layers.attention.attention import (set_default_quant_scales,
+                                                                        should_load_quant_weights)
+            quant_method = (self.quant_config.get_quant_method(self, prefix=self.layer_name)
+                            if self.quant_config else None)
+            if not should_load_quant_weights(quant_method):
+                set_default_quant_scales(self, register_buffer=False)
+        else:
+            # Non-FP8 kv_b_proj: use upstream logic as before.
+            MLAAttention.process_weights_after_loading(self, act_dtype)
+            self.W_UV = self.W_UV.contiguous()
+            self.W_UK_T = self.W_UK_T.contiguous()
 
     # NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
     # so we override and always return a new tensor
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index 45885ddbb..30c4b1c9f 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -967,6 +967,10 @@ def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
         weight_scale_inv = weight_scale_inv.squeeze(-1)
         layer.weight.data.copy_(weight)
         layer.weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
+        # Scale is now per-channel, not per-block; clear stale block size to
+        # prevent downstream code (e.g. scaled_dequantize) from using it as
+        # a group_shape that is incompatible with the 1D channel-wise scale.
+        layer.weight_block_size = None
         htorch.core.mark_step()
         return layer
     else:
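
Note (not part of the patch): below is a minimal standalone sketch of what the new channel-wise branch does for kv_b_proj, i.e. dequantizing a per-output-channel FP8 weight via a broadcast multiply and splitting the result into the MLA up-projection matrices W_UK / W_UV. The dimensions are illustrative placeholders, not DeepSeek-R1's real shapes.

```python
import torch

# Illustrative dimensions (assumptions, not the model's actual config).
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128
n_out = num_heads * (qk_nope_head_dim + v_head_dim)
act_dtype = torch.bfloat16

# Channel-wise FP8 weight [N_out, kv_lora_rank] with one scale per output channel,
# as produced when block FP8 is converted to channel-wise FP8.
weight_fp8 = torch.randn(n_out, kv_lora_rank).to(torch.float8_e4m3fn)
weight_scale_inv = torch.rand(n_out)  # 1D [N_out]

# Dequant via broadcast multiply, then transpose to [kv_lora_rank, N_out].
ws = weight_scale_inv.view(-1, 1).to(act_dtype)
kv_b_proj_weight = (weight_fp8.to(act_dtype) * ws).T

# Split into the per-head up-projection matrices used by MLA.
kv_b_proj_weight = kv_b_proj_weight.view(kv_lora_rank, num_heads,
                                         qk_nope_head_dim + v_head_dim)
W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)
W_UV = W_UV.transpose(0, 1).contiguous()   # [num_heads, kv_lora_rank, v_head_dim]
W_UK_T = W_UK.permute(1, 2, 0).contiguous()  # [num_heads, qk_nope_head_dim, kv_lora_rank]
print(W_UV.shape, W_UK_T.shape)
```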