32 changes: 29 additions & 3 deletions vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py
@@ -42,12 +42,13 @@ def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
not in [
torch.float32,
torch.bfloat16,
torch.float16,
]
):
return (
False,
"Dynamic4bitLinearKernel on Arm requires Float32 or"
" BFloat16 activations",
" BFloat16 or Float16 activations",
)
if c.full_weight_shape[0] % c.group_size != 0:
return (
@@ -118,8 +119,30 @@ def apply_weights(
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# PyTorch / KleidiAI kernels natively support the following configs:
# - channelwise with bfloat16 / float32 activations
# - groupwise with float32 activations
# To support:
# - groupwise with bfloat16/float16 activations: we need to upcast
# activations to float32 before matmul and downcast back to bfloat16/float16
# - channelwise with float16 activations: we need to upcast activations to
# float32 before matmul and downcast back to float16
# Note: these activations will be dynamically quantized to int8 by the kernel.

c = self.config
is_groupwise = c.group_size != c.partition_weight_shape[0]
# dtype of activations before they get dynamically quantized to int8
original_pre_quant_act_dtype = x.dtype
pre_quant_act_dtype = original_pre_quant_act_dtype
if (
is_groupwise and pre_quant_act_dtype == torch.bfloat16
) or pre_quant_act_dtype == torch.float16:
pre_quant_act_dtype = torch.float32

x_2d = x.reshape(-1, x.shape[-1])
if pre_quant_act_dtype != original_pre_quant_act_dtype:
x_2d = x_2d.to(pre_quant_act_dtype)

out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

w_q = getattr(layer, self.w_q_name)
@@ -129,5 +152,8 @@ def apply_weights(
c.group_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
return output.reshape(out_shape)
).reshape(out_shape)

if pre_quant_act_dtype != original_pre_quant_act_dtype:
output = output.to(original_pre_quant_act_dtype)
return output
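
For readers following the change outside the diff, the sketch below isolates the upcast/downcast pattern the new comments describe. It is a minimal, hypothetical reconstruction, not the vLLM code path: `quantized_matmul_4bit` is a placeholder for the real KleidiAI-backed dynamic 4-bit kernel, and the shapes and flags are illustrative only.

```python
import torch


def quantized_matmul_4bit(x_2d: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the dynamically quantized 4-bit kernel;
    # a plain float matmul keeps the sketch self-contained.
    return x_2d @ w.to(x_2d.dtype)


def apply_with_dtype_fallback(
    x: torch.Tensor, w: torch.Tensor, is_groupwise: bool
) -> torch.Tensor:
    original_dtype = x.dtype
    compute_dtype = original_dtype
    # Upcast when the kernel lacks native support for this (layout, dtype) pair:
    # groupwise + bfloat16, or float16 in either layout.
    if (
        is_groupwise and compute_dtype == torch.bfloat16
    ) or compute_dtype == torch.float16:
        compute_dtype = torch.float32

    x_2d = x.reshape(-1, x.shape[-1])
    if compute_dtype != original_dtype:
        x_2d = x_2d.to(compute_dtype)

    out = quantized_matmul_4bit(x_2d, w).reshape(*x.shape[:-1], w.shape[-1])
    if compute_dtype != original_dtype:
        # Hand the result back in the caller's original activation dtype.
        out = out.to(original_dtype)
    return out


if __name__ == "__main__":
    x = torch.randn(2, 8, 16, dtype=torch.float16)
    w = torch.randn(16, 32, dtype=torch.float32)
    y = apply_with_dtype_fallback(x, w, is_groupwise=True)
    assert y.dtype == torch.float16 and y.shape == (2, 8, 32)
```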