diff --git a/tests/unit_tests/ops/test_hpu_compressed_tensors.py b/tests/unit_tests/ops/test_hpu_compressed_tensors.py
index 4166bcc494..bd5180cfc9 100644
--- a/tests/unit_tests/ops/test_hpu_compressed_tensors.py
+++ b/tests/unit_tests/ops/test_hpu_compressed_tensors.py
@@ -7,7 +7,8 @@
 from unittest.mock import MagicMock
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
 from vllm_gaudi.ops.hpu_compressed_tensors import (HPUCompressedTensorsLinearMethod, HPUCompressedTensorsW8A8Fp8,
-                                                   HPUCompressedTensorsWNA16, HPUCompressedTensorsWNA16MoEMethod)
+                                                   HPUCompressedTensorsWNA16, HPUCompressedTensorsWNA16MoEMethod,
+                                                   HPUCompressedTensorsW8A8Fp8MoEMethod)
 from vllm_gaudi.utils import HPUCompileConfig
 from vllm.forward_context import override_forward_context
 from safetensors import safe_open
@@ -392,3 +393,173 @@ def test_compressed_tensors_wna16_moe_method(default_vllm_config: None, dist_ini
 
     # Check correctness
     torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
+
+
+def test_compressed_tensors_linear_method_w8a8fp8_block(default_vllm_config: None, dist_init):
+    """Weight per-block, activation dynamic per-group.
+    Config based on mistralai/Mistral-Large-3-675B-Instruct-2512 params.json
+    """
+    block_structure = [128, 128]
+    config = {
+        'config_groups': {
+            'FP8_BLOCK': {
+                'format': 'float-quantized',
+                'input_activations': {
+                    'actorder': None,
+                    'block_structure': None,
+                    'dynamic': True,
+                    'group_size': 128,
+                    'num_bits': 8,
+                    'observer': None,
+                    'observer_kwargs': {},
+                    'strategy': 'group',
+                    'symmetric': True,
+                    'type': 'float'
+                },
+                'output_activations': None,
+                'targets': ['Linear'],
+                'weights': {
+                    'actorder': None,
+                    'block_structure': block_structure,
+                    'dynamic': False,
+                    'group_size': None,
+                    'num_bits': 8,
+                    'observer': 'static_minmax',
+                    'observer_kwargs': {},
+                    'strategy': 'block',
+                    'symmetric': True,
+                    'type': 'float'
+                }
+            }
+        },
+        'format': 'float-quantized',
+        'global_compression_ratio': None,
+        'ignore': [],
+        'kv_cache_scheme': None,
+        'quant_method': 'compressed-tensors',
+        'quantization_status': 'compressed'
+    }
+    oot_quant_config = CompressedTensorsConfig.from_config(config)
+    input_size = 256
+    output_size = 256
+    block_n, block_k = block_structure
+
+    oot_op = create_row_parallel_linear(input_size=input_size, output_size=output_size,
+                                        quant_config=oot_quant_config).to("hpu")
+    assert isinstance(oot_op.quant_method, HPUCompressedTensorsLinearMethod)
+    assert isinstance(oot_op.scheme, HPUCompressedTensorsW8A8Fp8)
+
+    # Create synthetic FP8 block-quantized weights
+    weight_bf16 = torch.randn(output_size, input_size, dtype=torch.bfloat16, device="hpu")
+    weight_fp8 = weight_bf16.to(torch.float8_e4m3fn)
+    scale_rows = (output_size + block_n - 1) // block_n
+    scale_cols = (input_size + block_k - 1) // block_k
+    weight_scale = torch.ones(scale_rows, scale_cols, dtype=torch.float32, device="hpu")
+    oot_op.weight.data.copy_(weight_fp8)
+    oot_op.weight_scale.data.copy_(weight_scale)
+
+    oot_op.quant_method.process_weights_after_loading(oot_op)
+
+    # Verify blockwise post-processing created the expected attributes
+    assert hasattr(oot_op, "weight_scale_inv"), "weight_scale_inv should be created for block strategy"
+    assert not hasattr(oot_op, "weight_scale"), "weight_scale should be removed after aliasing"
+
+    # Execute layer with synthetic input
+    x = torch.randn(1, 4, input_size, dtype=torch.bfloat16, device="hpu")
+    out = oot_op.scheme.apply_weights(oot_op, x)
+    assert out.shape == (1, 4, output_size)
+    assert out.dtype == torch.bfloat16
+
+
+def test_compressed_tensors_w8a8fp8_block_moe_method(default_vllm_config: None, dist_init):
+    """FP8 block-quantized MoE: weight per-block, activation dynamic per-group.
+    Config based on mistralai/Mistral-Large-3-675B-Instruct-2512 params.json
+    """
+    block_structure = [128, 128]
+    config = {
+        'config_groups': {
+            'FP8_BLOCK': {
+                'format': 'float-quantized',
+                'input_activations': {
+                    'actorder': None,
+                    'block_structure': None,
+                    'dynamic': True,
+                    'group_size': 128,
+                    'num_bits': 8,
+                    'observer': None,
+                    'observer_kwargs': {},
+                    'strategy': 'group',
+                    'symmetric': True,
+                    'type': 'float'
+                },
+                'output_activations': None,
+                'targets': ['Linear'],
+                'weights': {
+                    'actorder': None,
+                    'block_structure': block_structure,
+                    'dynamic': False,
+                    'group_size': None,
+                    'num_bits': 8,
+                    'observer': 'static_minmax',
+                    'observer_kwargs': {},
+                    'strategy': 'block',
+                    'symmetric': True,
+                    'type': 'float'
+                }
+            }
+        },
+        'format': 'float-quantized',
+        'global_compression_ratio': None,
+        'ignore': [],
+        'kv_cache_scheme': None,
+        'quant_method': 'compressed-tensors',
+        'quantization_status': 'compressed'
+    }
+    oot_quant_config = CompressedTensorsConfig.from_config(config)
+
+    oot_op = create_fused_moe(oot_quant_config).to("hpu")
+    assert isinstance(oot_op.quant_method, HPUCompressedTensorsW8A8Fp8MoEMethod)
+
+    num_experts = 128
+    hidden_size = 512
+    intermediate_size = 256
+    block_n, block_k = block_structure
+
+    # Create synthetic FP8 block-quantized MoE weights
+    w13_weight = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16,
+                             device="hpu").to(torch.float8_e4m3fn)
+    w2_weight = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.bfloat16,
+                            device="hpu").to(torch.float8_e4m3fn)
+
+    w13_scale_rows = (2 * intermediate_size + block_n - 1) // block_n
+    w13_scale_cols = (hidden_size + block_k - 1) // block_k
+    w2_scale_rows = (hidden_size + block_n - 1) // block_n
+    w2_scale_cols = (intermediate_size + block_k - 1) // block_k
+
+    w13_weight_scale = torch.ones(num_experts, w13_scale_rows, w13_scale_cols, dtype=torch.float32, device="hpu")
+    w2_weight_scale = torch.ones(num_experts, w2_scale_rows, w2_scale_cols, dtype=torch.float32, device="hpu")
+
+    oot_op.w13_weight.data.copy_(w13_weight)
+    oot_op.w2_weight.data.copy_(w2_weight)
+    oot_op.w13_weight_scale.data.copy_(w13_weight_scale)
+    oot_op.w2_weight_scale.data.copy_(w2_weight_scale)
+
+    oot_op.quant_method.process_weights_after_loading(oot_op)
+
+    # Verify blockwise post-processing created the expected attributes
+    assert hasattr(oot_op, "w13_weight_scale_inv"), "w13_weight_scale_inv should be created for block MoE"
+    assert hasattr(oot_op, "w2_weight_scale_inv"), "w2_weight_scale_inv should be created for block MoE"
+    assert not hasattr(oot_op, "w13_weight_scale"), "w13_weight_scale should be removed after aliasing"
+    assert not hasattr(oot_op, "w2_weight_scale"), "w2_weight_scale should be removed after aliasing"
+
+    # Execute layer with synthetic input
+    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device="hpu")
+    router_logits = torch.randn(4, num_experts, dtype=torch.bfloat16, device="hpu")
+
+    mock_ctx = MagicMock(spec=["dp_metadata"])
+    mock_ctx.dp_metadata = None
+    with override_forward_context(mock_ctx):
+        out = oot_op.runner.forward_impl(oot_op, hidden_states, router_logits, hidden_states)
+
+    assert out.shape == hidden_states.shape
+    assert out.dtype == torch.bfloat16
diff --git a/vllm_gaudi/attention/oot_mla.py b/vllm_gaudi/attention/oot_mla.py
index 5f14c7a82d..66506521b1 100644
--- a/vllm_gaudi/attention/oot_mla.py
+++ b/vllm_gaudi/attention/oot_mla.py
@@ -173,14 +173,14 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
             # Channel-wise FP8 (produced by VLLM_HPU_FORCE_CHANNEL_FP8=True):
             # one scale per output channel; dequant via simple broadcast multiply.
             ws = weight_scale_inv.view(-1, 1).to(act_dtype)  # [N_out, 1]
-            kv_b_proj_weight = (weight.to(act_dtype) * ws).T
+            kv_b_proj_weight = (weight.to(act_dtype) * ws).T.contiguous()
         else:
             # Block FP8 (force_channel_fp8=False): use HPU block dequant.
             from vllm_gaudi.extension.ops import dequant_block_fp8_weight_naive
             orig_M = kv_b_proj.orig_M.item() if hasattr(kv_b_proj, 'orig_M') else None
             orig_N = kv_b_proj.orig_N.item() if hasattr(kv_b_proj, 'orig_N') else None
             kv_b_proj_weight = dequant_block_fp8_weight_naive(
-                weight,
+                weight.contiguous(),
                 weight_scale_inv,
                 kv_b_proj.weight_block_size,
                 dtype=act_dtype,
diff --git a/vllm_gaudi/ops/hpu_compressed_tensors.py b/vllm_gaudi/ops/hpu_compressed_tensors.py
index dfc717c7ed..b16913cc0e 100644
--- a/vllm_gaudi/ops/hpu_compressed_tensors.py
+++ b/vllm_gaudi/ops/hpu_compressed_tensors.py
@@ -6,6 +6,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import WEIGHT_LOADER_V2_SUPPORTED
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEConfig)
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import convert_to_channelwise, all_close_1d
@@ -54,6 +55,24 @@
 SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR, QuantizationStrategy.BLOCK]
 
 
+def _hpu_weight_scale_alias(layer: torch.nn.Module, scale_name: str, hpu_scale_name: str) -> None:
+    """Rename a weight scale parameter to the naming convention expected by HPU ops.
+
+    For example, for block quantization, HPU ops expect `weight_scale` to be named `weight_scale_inv`.
+
+    This preserves compatibility across checkpoints and codepaths that use either
+    naming convention.
+ """ + if hasattr(layer, hpu_scale_name) or not hasattr(layer, scale_name): + return + + # Rename weight_scale to convention expected by HPU ops + scale = getattr(layer, scale_name) + scale = scale.data if isinstance(scale, torch.nn.Parameter) else scale + layer.register_parameter(hpu_scale_name, torch.nn.Parameter(scale, requires_grad=False)) + delattr(layer, scale_name) + + @CustomOp.register_oot(name='CompressedTensorsLinearMethod') class HPUCompressedTensorsLinearMethod(OrigCompressedTensorsLinearMethod): @@ -111,14 +130,15 @@ def get_hpu_scheme(self, layer: torch.nn.Module): def dequant_fp8_weight(self, layer: torch.nn.Module) -> torch.Tensor: if layer.scheme.strategy == QuantizationStrategy.CHANNEL: # weights were quantized per-channel - dequant_weight = layer.weight.to(layer.weight_scale.dtype) * layer.weight_scale.squeeze() + weight_scale = layer.weight_scale_inv if hasattr(layer, "weight_scale_inv") else layer.weight_scale + dequant_weight = layer.weight.to(weight_scale.dtype) * weight_scale.squeeze() return dequant_weight.to(torch.bfloat16).t() elif layer.scheme.strategy == QuantizationStrategy.BLOCK: if hasattr(layer, "updated_fp8_weight") and layer.updated_fp8_weight: return layer.weight dequant_weight = hpu_ops.dequant_block_fp8_weight_naive( - layer.weight.t(), - layer.weight_scale.data, + layer.weight, + layer.weight_scale_inv.data, layer.weight_block_size, original_M=layer.orig_M, original_N=layer.orig_N, @@ -153,7 +173,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ws_channelwise = convert_to_channelwise(layer.weight_scale, layer.logical_widths) layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False) elif layer.scheme.strategy == QuantizationStrategy.BLOCK: + # Rename blockwise quantization scales to match fp8_block_linear_postprocess_weights + # Needed for models like Mistral-Large-3-675B + assert self.is_static_input_scheme is False + _hpu_weight_scale_alias(layer, "weight_scale", "weight_scale_inv") + layer.quant_config.weight_block_size = self.weight_block_size layer = hpu_ops.fp8_block_linear_postprocess_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8) + return else: # required by torch.compile to be torch.nn.Parameter layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) @@ -248,14 +274,26 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, layer.register_parameter("input_scale", input_scale) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None): - weight_scale = layer.weight_scale.transpose(0, 1) if layer.weight_scale.dim() > 1 else layer.weight_scale + if self.weight_block_size is not None and layer.weight.dtype == torch.float8_e4m3fn: + return hpu_ops.apply_block_fp8_linear_hpu( + input=x, + layer=layer, + block_size=self.weight_block_size, + bias=bias, + do_unpad=True, + force_channel_fp8=envs.VLLM_HPU_FORCE_CHANNEL_FP8, + ) + weight_scale = layer.weight_scale_inv if hasattr(layer, "weight_scale_inv") else layer.weight_scale + weight_scale = weight_scale.transpose(0, 1) if weight_scale.dim() > 1 else weight_scale input_scale = getattr(layer, 'input_scale', None) - return hpu_ops.apply_fp8_linear_hpu(input=x, - weight=layer.weight, - weight_scale=weight_scale, - input_scale=input_scale, - bias=bias, - trans_B=False) + input_2d = x.view(-1, x.shape[-1]) + output = hpu_ops.apply_fp8_linear_hpu(input=input_2d, + weight=layer.weight, + weight_scale=weight_scale, + input_scale=input_scale, + bias=bias, + 
trans_B=False) + return output.view(*x.shape[:-1], -1) @CustomOp.register_oot(name='CompressedTensorsW8A8Fp8MoEMethod') @@ -375,11 +413,39 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.block_quant: assert layer.weight_block_size is not None + # Rename blockwise quantization scales to match fp8_block_moe_prepare_weights + # Needed for models like Mistral-Large-3-675B + _hpu_weight_scale_alias(layer, "w13_weight_scale", "w13_weight_scale_inv") + _hpu_weight_scale_alias(layer, "w2_weight_scale", "w2_weight_scale_inv") + layer.quant_config.weight_block_size = self.weight_block_size layer = hpu_ops.fp8_block_moe_prepare_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8) else: layer = hpu_ops.fp8_channel_moe_prepare_weights(layer) return + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: + # On HPU, block quantization renames w13_weight_scale/w2_weight_scale + # to w13_weight_scale_inv/w2_weight_scale_inv via _hpu_weight_scale_alias. + # Use the renamed attributes for block quant. + if self.block_quant: + w13_scale = layer.w13_weight_scale_inv + w2_scale = layer.w2_weight_scale_inv + else: + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + return FusedMoEQuantConfig.make( + quant_dtype=self.fp8_backend, + w1_scale=w13_scale, + w2_scale=w2_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=is_per_token, + per_out_ch_quant=is_per_token, + block_shape=self.weight_block_size, + ) + def apply_monolithic( self, layer: FusedMoE,