diff --git a/vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py b/vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py index 3dfe06f1b130..d0515027628e 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py @@ -42,12 +42,13 @@ def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: not in [ torch.float32, torch.bfloat16, + torch.float16, ] ): return ( False, "Dynamic4bitLinearKernel on Arm requires Float32 or" - " BFloat16 activations", + " BFloat16 or Float16 activations", ) if c.full_weight_shape[0] % c.group_size != 0: return ( @@ -118,8 +119,30 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + # PyTorch / KleidiAI kernels natively support the following configs: + # - channelwise with bfloat16 / float32 activations + # - groupwise with float32 activations + # To support: + # - groupwise with bfloat16/float16 activations: we need to upcast + # activations to float32 before matmul and downcast back to bfloat16/float16 + # - channelwise with float16 activations, we need to upcast activations to + # float32 before matmul and downcast back to float16 + # Note: these activations will be dynamically quantized to int8 by the kernel. + c = self.config + is_groupwise = c.group_size != c.partition_weight_shape[0] + # dtype of activations before they get dynamically quantized to int8 + original_pre_quant_act_dtype = x.dtype + pre_quant_act_dtype = original_pre_quant_act_dtype + if ( + is_groupwise and pre_quant_act_dtype == torch.bfloat16 + ) or pre_quant_act_dtype == torch.float16: + pre_quant_act_dtype = torch.float32 + x_2d = x.reshape(-1, x.shape[-1]) + if pre_quant_act_dtype != original_pre_quant_act_dtype: + x_2d = x_2d.to(pre_quant_act_dtype) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) w_q = getattr(layer, self.w_q_name) @@ -129,5 +152,8 @@ def apply_weights( c.group_size, c.partition_weight_shape[0], c.partition_weight_shape[1], - ) - return output.reshape(out_shape) + ).reshape(out_shape) + + if pre_quant_act_dtype != original_pre_quant_act_dtype: + output = output.to(original_pre_quant_act_dtype) + return output