20 changes: 14 additions & 6 deletions tensorrt_llm/layers/moe.py
@@ -40,7 +40,8 @@
from ..parameter import Parameter
from ..plugin import TRT_LLM_PLUGIN_NAMESPACE
from ..quantization import GroupwiseQuantAlgo, QuantMode
from ..quantization.functional import (postprocess_weight_only,
from ..quantization.functional import (get_weight_scale_interleave_factor,
postprocess_weight_only,
preprocess_weights_for_mixed_gemm,
quantize)
from .linear import RowLinear
@@ -489,11 +490,18 @@ def __init__(self, in_features: int, out_features: int,
self.alpha = Parameter(shape=(experts_per_node, ),
dtype=trt.float32)
elif quant_mode.has_per_group_scaling():
self.weight = Parameter(shape=(experts_per_node, in_features,
out_features // 4),
dtype=dtype)
scale_shape = (experts_per_node, in_features // group_size,
out_features)
self.weight = Parameter(
shape=(experts_per_node, in_features,
out_features // 4), # int4 <--> fp16/bf16
dtype=dtype)
if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA:
scale_interleave_factor = get_weight_scale_interleave_factor(
in_features, group_size)
else:
scale_interleave_factor = 1
scale_shape = (experts_per_node,
in_features // group_size // scale_interleave_factor,
out_features * scale_interleave_factor)
self.weights_scaling_factor = Parameter(shape=scale_shape,
dtype=dtype)
if groupwise_quant_algo & GroupwiseQuantAlgo.ZERO:
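For orientation, here is a minimal sketch (not part of the diff) of the shape arithmetic introduced above; the dimensions and the interleave factor are illustrative assumptions, with the factor chosen as `get_weight_scale_interleave_factor` would pick it on Hopper for these sizes.

```python
# Sketch only: how the W4A8 groupwise scale shape changes when interleaved.
# The dimensions below are illustrative assumptions, not values from this PR.
experts_per_node, in_features, out_features, group_size = 8, 4096, 14336, 128

plain = (experts_per_node, in_features // group_size, out_features)

factor = 4  # assumed Hopper W4A8 factor, since in_features % (4 * group_size) == 0
interleaved = (experts_per_node,
               in_features // group_size // factor,
               out_features * factor)

print(plain)        # (8, 32, 14336)
print(interleaved)  # (8, 8, 57344) -- same element count, groups folded into columns
```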
108 changes: 65 additions & 43 deletions tensorrt_llm/quantization/functional.py
@@ -950,10 +950,12 @@ def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode):
return qweight, scale


def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
quant_mode: torch.dtype,
act_dtype: torch.dtype,
sm_: int = -1) -> torch.Tensor:
def preprocess_weights_for_mixed_gemm(
tensor: torch.Tensor,
quant_mode: torch.dtype,
act_dtype: torch.dtype,
sm_: int = -1,
do_weight_interleave: bool = True) -> torch.Tensor:
sm_ = sm_ if sm_ > 0 else get_sm_version()
if len(tensor.shape) == 2:
tensor = tensor.unsqueeze(0)
@@ -988,13 +990,12 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
assert (num_rows % B_ROWS_PER_MMA == 0)
assert (num_cols % MMA_SHAPE_N == 0)

row_idx_list = [
(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][row_idx %
B_ROWS_PER_MMA]
for row_idx in range(num_rows)
]
tensor = tensor[:, row_idx_list, :]
if do_weight_interleave:
row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][
row_idx % B_ROWS_PER_MMA]
for row_idx in range(num_rows)]
tensor = tensor[:, row_idx_list, :]

# subbyte_transpose
original_shape = tensor.shape
@@ -1010,42 +1011,63 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
else:
tensor = tensor.permute(0, 2, 1).reshape(original_shape)

# interleave_column_major_tensor
interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
if interleave > 1 and sm_ < 90:
rows_per_tile = 128 * 8 // BITS_PER_ELT_A
elts_in_int32 = 32 // BITS_PER_ELT_B

assert (num_rows % elts_in_int32 == 0)
assert (num_rows % rows_per_tile == 0)

tensor = tensor.reshape(num_experts, -1, interleave,
num_rows // rows_per_tile,
rows_per_tile * 4 // elts_in_int32)
tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)

# add_bias_and_interleave_quantized_tensor_inplace
if BITS_PER_ELT_B == 8:
tensor += -256 * (tensor > 127).byte() + 128
tensor = tensor.reshape(-1, 4)[:, [0, 2, 1, 3]].reshape(tensor.shape)
elif BITS_PER_ELT_B == 4:
tensor = tensor.view(torch.uint8)
high_tensor = (tensor >> 4).unsqueeze(-1)
low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
new_tensor = torch.cat([low_tensor, high_tensor],
dim=-1).reshape(tensor.shape[0], tensor.shape[1],
-1)
new_tensor = new_tensor.reshape(
-1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
new_tensor += -16 * (new_tensor > 7).byte() + 8
new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
tensor = new_tensor.view(torch.int8)
else:
raise NotImplementedError
if do_weight_interleave:
# interleave_column_major_tensor
interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
if interleave > 1 and sm_ < 90:
rows_per_tile = 128 * 8 // BITS_PER_ELT_A
elts_in_int32 = 32 // BITS_PER_ELT_B

assert (num_rows % elts_in_int32 == 0)
assert (num_rows % rows_per_tile == 0)

tensor = tensor.reshape(num_experts, -1, interleave,
num_rows // rows_per_tile,
rows_per_tile * 4 // elts_in_int32)
tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)

# add_bias_and_interleave_quantized_tensor_inplace
if BITS_PER_ELT_B == 8:
tensor += -256 * (tensor > 127).byte() + 128
tensor = tensor.reshape(-1, 4)[:,
[0, 2, 1, 3]].reshape(tensor.shape)
elif BITS_PER_ELT_B == 4:
tensor = tensor.view(torch.uint8)
high_tensor = (tensor >> 4).unsqueeze(-1)
low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
new_tensor = torch.cat([low_tensor, high_tensor],
dim=-1).reshape(tensor.shape[0],
tensor.shape[1], -1)
new_tensor = new_tensor.reshape(
-1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
new_tensor += -16 * (new_tensor > 7).byte() + 8
new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
tensor = new_tensor.view(torch.int8)
else:
raise NotImplementedError

return tensor.squeeze(0).contiguous()


def get_weight_scale_interleave_factor(interleaved_dim: int,
group_size: int = 128) -> int:
    # Compute the weight-scale interleave factor for W4A8 groupwise MoE quantization.
    # Only Hopper W4A8 interleaves the weight scale; other architectures and Hopper W4A16 default to 1.
factor = 1
if get_sm_version() == 90:
if interleaved_dim % (4 * group_size) == 0:
factor = 4
elif interleaved_dim % (2 * group_size) == 0:
factor = 2
elif interleaved_dim % group_size == 0:
factor = 1
else:
raise NotImplementedError(
f"Interleaved dimension must be a multiple of group_size ({group_size}), received {interleaved_dim}."
)
return factor


def validate_group_size(layer):
# TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented.
W4A8_AWQ = 8
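To illustrate the selection rule in `get_weight_scale_interleave_factor` added above, here is a hedged, standalone sketch that takes the SM version as an explicit parameter (the real helper queries `get_sm_version()`); the sample dimensions are assumptions.

```python
# Standalone sketch of the interleave-factor rule; the SM version is passed in
# explicitly so the example runs without a GPU. Sample inputs are assumptions.
def weight_scale_interleave_factor(interleaved_dim: int,
                                   group_size: int = 128,
                                   sm: int = 90) -> int:
    if sm != 90:
        return 1  # only Hopper W4A8 interleaves the weight scales
    if interleaved_dim % (4 * group_size) == 0:
        return 4
    if interleaved_dim % (2 * group_size) == 0:
        return 2
    if interleaved_dim % group_size == 0:
        return 1
    raise NotImplementedError(
        f"Interleaved dimension must be a multiple of group_size ({group_size}).")


print(weight_scale_interleave_factor(4096))         # 4  (multiple of 4 * 128)
print(weight_scale_interleave_factor(768))          # 2  (multiple of 256, not of 512)
print(weight_scale_interleave_factor(128))          # 1
print(weight_scale_interleave_factor(4096, sm=89))  # 1  (Ada: no scale interleave)
```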
@@ -14,19 +14,22 @@
# limitations under the License.
import unittest

import pytest

# isort: off
import torch
# isort: on

from parameterized import parameterized
from utils.util import (create_session, run_session, skip_non_ada_unittest,
from utils.util import (create_session, run_session,
skip_neither_ada_nor_hopper_unittest,
unittest_name_func)

import tensorrt_llm
import tensorrt_llm.quantization.functional
from tensorrt_llm import Tensor
from tensorrt_llm._utils import (str_dtype_to_trt, torch_to_numpy,
trt_dtype_to_str)
from tensorrt_llm._utils import (get_sm_version, str_dtype_to_trt,
torch_to_numpy, trt_dtype_to_str)
from tensorrt_llm.layers.moe import MoeConfig
from tensorrt_llm.quantization import QuantMode

@@ -66,7 +69,8 @@ def create_trt_session(
norm_mode = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
quant_mode = QuantMode.use_weight_only(True, True)
k = act.shape[1]
n = weight_scaling_factor_1.shape[-1] // 2
n = fc2_prequant_scale.shape[
-1] # get the original n from prequant scale because either weight or scale could be interleaved
num_experts = weight_scaling_factor_1.shape[0]

with tensorrt_llm.net_guard(network):
@@ -202,18 +206,50 @@ def _woq_moe_groupwise_matmul(self,
ref_weight_1 += zero_1.repeat_interleave(group_size, dim=1)
ref_weight_2 += zero_2.repeat_interleave(group_size, dim=1)
activation_type = torch.float8_e4m3fn if has_alpha else activation_dtype
do_weight_interleave = get_sm_version(
) != 90 or not has_alpha # Hopper w4a8 does not interleave weight
cuda_q_weight_1 = preprocessor(
unprocessed_weight_1.cpu(), quantized_weight_dtype,
activation_type).view(activation_dtype).cpu()
unprocessed_weight_1.cpu(),
quantized_weight_dtype,
activation_type,
do_weight_interleave=do_weight_interleave).view(
activation_dtype).cpu()
cuda_q_weight_2 = preprocessor(
unprocessed_weight_2.cpu(), quantized_weight_dtype,
activation_type).view(activation_dtype).cpu()
if has_alpha and activation_dtype == torch.bfloat16:
unprocessed_weight_2.cpu(),
quantized_weight_dtype,
activation_type,
do_weight_interleave=do_weight_interleave).view(
activation_dtype).cpu()
if get_sm_version() == 89 and has_alpha:
scale_1 = scale_1.to(torch.float16).view(activation_dtype)
scale_2 = scale_2.to(torch.float16).view(activation_dtype)
zero_1 = zero_1.to(torch.float16).view(activation_dtype)
zero_2 = zero_2.to(torch.float16).view(activation_dtype)

if get_sm_version() == 90 and has_alpha:
if has_zero:
pytest.skip(
"has_zero is not supported in Hopper with WINT4AFP8.")

def interleave_scales(scales: torch.Tensor, interleave_dim: int):
# [num_experts, num_groups, num_cols] --> [num_experts, num_groups // interleave, num_cols * interleave]
# Note: num_groups = num_rows // group_size
E, G, C = scales.shape
I = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
interleave_dim, group_size)
assert G % I == 0, f"Group dimension ({G}) must be divisible by interleave factor ({I})."
scales_interleaved = scales.reshape(E, G // I, I, C)
scales_interleaved = scales_interleaved.permute(0, 1, 3, 2)
scales_interleaved = scales_interleaved.reshape(
E, G // I, C * I)
return scales_interleaved.contiguous()

scale_1 = scale_1.to(torch.bfloat16).view(activation_dtype)
scale_2 = scale_2.to(torch.bfloat16).view(activation_dtype)
scale_1 = interleave_scales(scale_1, k)
scale_2 = interleave_scales(scale_2, n)
zero_1, zero_2 = None, None

session = self.create_trt_session(
activation_dtype_str, activation, router, pre_quant_scale_1,
pre_quant_scale_2, cuda_q_weight_1, cuda_q_weight_2, scale_1,
@@ -278,7 +314,7 @@ def test_moe_w4a16(self, m, n, k, experts, dtype, has_pre_quant, has_zero):
(1, 14336, 4096, 8, "bfloat16", True, False),
(1, 14336, 4096, 8, "bfloat16", True, True)],
name_func=unittest_name_func)
@skip_non_ada_unittest
@skip_neither_ada_nor_hopper_unittest
def test_moe_w4a8(self, m, n, k, experts, dtype, has_pre_quant, has_zero):

self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2,
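As a small self-check (outside the PR) of the scale packing exercised by `interleave_scales` in the test above, the round trip below reshapes a random scale tensor into the interleaved layout and back; the shapes and the factor are illustrative assumptions.

```python
# Sketch: interleave / de-interleave round trip for weight scales, mirroring
# the reshape/permute in interleave_scales above. Shapes are assumptions.
import torch

E, G, C, I = 2, 8, 16, 4  # experts, groups (rows // group_size), columns, factor
scales = torch.randn(E, G, C)

# [E, G, C] -> [E, G // I, C * I]: fold I consecutive groups into the column dim.
packed = scales.reshape(E, G // I, I, C).permute(0, 1, 3, 2).reshape(E, G // I, C * I)

# Inverse: [E, G // I, C * I] -> [E, G, C]
unpacked = packed.reshape(E, G // I, C, I).permute(0, 1, 3, 2).reshape(E, G, C)

assert torch.equal(unpacked, scales)  # pure re-layout, values unchanged
```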