
Commit 980929e

[https://nvbugs/5410687][fix] Hopper w4a8 groupwise MoE interleave (#6708)
Signed-off-by: Haohang Huang <[email protected]>
1 parent db8dc97 · commit 980929e

File tree (3 files changed: +125 −59 lines)

tensorrt_llm/layers/moe.py
tensorrt_llm/quantization/functional.py
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py


tensorrt_llm/layers/moe.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -40,7 +40,8 @@
 from ..parameter import Parameter
 from ..plugin import TRT_LLM_PLUGIN_NAMESPACE
 from ..quantization import GroupwiseQuantAlgo, QuantMode
-from ..quantization.functional import (postprocess_weight_only,
+from ..quantization.functional import (get_weight_scale_interleave_factor,
+                                       postprocess_weight_only,
                                        preprocess_weights_for_mixed_gemm,
                                        quantize)
 from .linear import RowLinear
@@ -489,11 +490,18 @@ def __init__(self, in_features: int, out_features: int,
             self.alpha = Parameter(shape=(experts_per_node, ),
                                    dtype=trt.float32)
         elif quant_mode.has_per_group_scaling():
-            self.weight = Parameter(shape=(experts_per_node, in_features,
-                                           out_features // 4),
-                                    dtype=dtype)
-            scale_shape = (experts_per_node, in_features // group_size,
-                           out_features)
+            self.weight = Parameter(
+                shape=(experts_per_node, in_features,
+                       out_features // 4),  # int4 <--> fp16/bf16
+                dtype=dtype)
+            if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA:
+                scale_interleave_factor = get_weight_scale_interleave_factor(
+                    in_features, group_size)
+            else:
+                scale_interleave_factor = 1
+            scale_shape = (experts_per_node,
+                           in_features // group_size // scale_interleave_factor,
+                           out_features * scale_interleave_factor)
             self.weights_scaling_factor = Parameter(shape=scale_shape,
                                                     dtype=dtype)
             if groupwise_quant_algo & GroupwiseQuantAlgo.ZERO:
```
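For reference, the new scale_shape folds the weight-scale interleave factor into the group dimension. Below is a minimal standalone sketch of that arithmetic; sm90_scale_interleave_factor is a hypothetical local mirror of the SM90 branch of get_weight_scale_interleave_factor, and the dimensions are illustrative rather than taken from this commit.

```python
# Hypothetical mirror of the SM90 branch of get_weight_scale_interleave_factor,
# plus the scale_shape arithmetic from moe.py above; all numbers are illustrative.
def sm90_scale_interleave_factor(in_features: int, group_size: int = 128) -> int:
    for factor in (4, 2, 1):
        if in_features % (factor * group_size) == 0:
            return factor
    raise NotImplementedError(
        f"in_features must be a multiple of group_size ({group_size}), got {in_features}.")


experts_per_node, in_features, out_features, group_size = 8, 4096, 14336, 128
factor = sm90_scale_interleave_factor(in_features, group_size)  # 4, since 4096 % (4 * 128) == 0
scale_shape = (experts_per_node,
               in_features // group_size // factor,  # 4096 // 128 // 4 = 8
               out_features * factor)  # 14336 * 4 = 57344
print(scale_shape)  # (8, 8, 57344): same element count as the non-interleaved (8, 32, 14336)
```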

tensorrt_llm/quantization/functional.py

Lines changed: 65 additions & 43 deletions
```diff
@@ -950,10 +950,12 @@ def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode):
     return qweight, scale
 
 
-def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
-                                      quant_mode: torch.dtype,
-                                      act_dtype: torch.dtype,
-                                      sm_: int = -1) -> torch.Tensor:
+def preprocess_weights_for_mixed_gemm(
+        tensor: torch.Tensor,
+        quant_mode: torch.dtype,
+        act_dtype: torch.dtype,
+        sm_: int = -1,
+        do_weight_interleave: bool = True) -> torch.Tensor:
     sm_ = sm_ if sm_ > 0 else get_sm_version()
     if len(tensor.shape) == 2:
         tensor = tensor.unsqueeze(0)
@@ -988,13 +990,12 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
     assert (num_rows % B_ROWS_PER_MMA == 0)
     assert (num_cols % MMA_SHAPE_N == 0)
 
-    row_idx_list = [
-        (row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
-        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][row_idx %
-                                                              B_ROWS_PER_MMA]
-        for row_idx in range(num_rows)
-    ]
-    tensor = tensor[:, row_idx_list, :]
+    if do_weight_interleave:
+        row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
+                        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][
+                            row_idx % B_ROWS_PER_MMA]
+                        for row_idx in range(num_rows)]
+        tensor = tensor[:, row_idx_list, :]
 
     # subbyte_transpose
     original_shape = tensor.shape
@@ -1010,42 +1011,63 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
     else:
         tensor = tensor.permute(0, 2, 1).reshape(original_shape)
 
-    # interleave_column_major_tensor
-    interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
-    if interleave > 1 and sm_ < 90:
-        rows_per_tile = 128 * 8 // BITS_PER_ELT_A
-        elts_in_int32 = 32 // BITS_PER_ELT_B
-
-        assert (num_rows % elts_in_int32 == 0)
-        assert (num_rows % rows_per_tile == 0)
-
-        tensor = tensor.reshape(num_experts, -1, interleave,
-                                num_rows // rows_per_tile,
-                                rows_per_tile * 4 // elts_in_int32)
-        tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)
-
-    # add_bias_and_interleave_quantized_tensor_inplace
-    if BITS_PER_ELT_B == 8:
-        tensor += -256 * (tensor > 127).byte() + 128
-        tensor = tensor.reshape(-1, 4)[:, [0, 2, 1, 3]].reshape(tensor.shape)
-    elif BITS_PER_ELT_B == 4:
-        tensor = tensor.view(torch.uint8)
-        high_tensor = (tensor >> 4).unsqueeze(-1)
-        low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
-        new_tensor = torch.cat([low_tensor, high_tensor],
-                               dim=-1).reshape(tensor.shape[0], tensor.shape[1],
-                                               -1)
-        new_tensor = new_tensor.reshape(
-            -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
-        new_tensor += -16 * (new_tensor > 7).byte() + 8
-        new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
-        tensor = new_tensor.view(torch.int8)
-    else:
-        raise NotImplementedError
+    if do_weight_interleave:
+        # interleave_column_major_tensor
+        interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
+        if interleave > 1 and sm_ < 90:
+            rows_per_tile = 128 * 8 // BITS_PER_ELT_A
+            elts_in_int32 = 32 // BITS_PER_ELT_B
+
+            assert (num_rows % elts_in_int32 == 0)
+            assert (num_rows % rows_per_tile == 0)
+
+            tensor = tensor.reshape(num_experts, -1, interleave,
+                                    num_rows // rows_per_tile,
+                                    rows_per_tile * 4 // elts_in_int32)
+            tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)
+
+        # add_bias_and_interleave_quantized_tensor_inplace
+        if BITS_PER_ELT_B == 8:
+            tensor += -256 * (tensor > 127).byte() + 128
+            tensor = tensor.reshape(-1, 4)[:,
+                                           [0, 2, 1, 3]].reshape(tensor.shape)
+        elif BITS_PER_ELT_B == 4:
+            tensor = tensor.view(torch.uint8)
+            high_tensor = (tensor >> 4).unsqueeze(-1)
+            low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
+            new_tensor = torch.cat([low_tensor, high_tensor],
+                                   dim=-1).reshape(tensor.shape[0],
+                                                   tensor.shape[1], -1)
+            new_tensor = new_tensor.reshape(
+                -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
+            new_tensor += -16 * (new_tensor > 7).byte() + 8
+            new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
+            tensor = new_tensor.view(torch.int8)
+        else:
+            raise NotImplementedError
 
     return tensor.squeeze(0).contiguous()
 
 
+def get_weight_scale_interleave_factor(interleaved_dim: int,
+                                       group_size: int = 128) -> int:
+    # Calculate the weight_scale interleave factor for W4A8 groupwise MoE quant
+    # only Hopper w4a8 does interleave for weight scale, other arch or Hopper w4a16 default to 1
+    factor = 1
+    if get_sm_version() == 90:
+        if interleaved_dim % (4 * group_size) == 0:
+            factor = 4
+        elif interleaved_dim % (2 * group_size) == 0:
+            factor = 2
+        elif interleaved_dim % group_size == 0:
+            factor = 1
+        else:
+            raise NotImplementedError(
+                f"Interleaved dimension must be a multiple of group_size ({group_size}), received {interleaved_dim}."
+            )
+    return factor
+
+
 def validate_group_size(layer):
     # TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented.
     W4A8_AWQ = 8
```
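A short usage sketch of the two pieces changed above. The dimensions, group size, and has_alpha flag are illustrative; the commented factor values assume a Hopper (SM90) GPU, since get_weight_scale_interleave_factor queries get_sm_version() internally.

```python
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.functional import get_weight_scale_interleave_factor

group_size = 128
print(get_weight_scale_interleave_factor(4096, group_size))  # 4 on SM90: 4096 % (4 * 128) == 0
print(get_weight_scale_interleave_factor(2816, group_size))  # 2 on SM90: multiple of 2 * 128, not 4 * 128
print(get_weight_scale_interleave_factor(1152, group_size))  # 1 on SM90: multiple of 128 only
# On any non-Hopper architecture the returned factor is always 1.

# Hopper w4a8 (the W4A8_ALPHA / has_alpha path) skips the weight interleave in
# preprocess_weights_for_mixed_gemm and interleaves the per-group scales instead;
# every other architecture / quant combination keeps the previous behavior.
has_alpha = True  # illustrative
do_weight_interleave = get_sm_version() != 90 or not has_alpha
```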

tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py

Lines changed: 46 additions & 10 deletions
```diff
@@ -14,19 +14,22 @@
 # limitations under the License.
 import unittest
 
+import pytest
+
 # isort: off
 import torch
 # isort: on
 
 from parameterized import parameterized
-from utils.util import (create_session, run_session, skip_non_ada_unittest,
+from utils.util import (create_session, run_session,
+                        skip_neither_ada_nor_hopper_unittest,
                         unittest_name_func)
 
 import tensorrt_llm
 import tensorrt_llm.quantization.functional
 from tensorrt_llm import Tensor
-from tensorrt_llm._utils import (str_dtype_to_trt, torch_to_numpy,
-                                 trt_dtype_to_str)
+from tensorrt_llm._utils import (get_sm_version, str_dtype_to_trt,
+                                 torch_to_numpy, trt_dtype_to_str)
 from tensorrt_llm.layers.moe import MoeConfig
 from tensorrt_llm.quantization import QuantMode
 
@@ -66,7 +69,8 @@ def create_trt_session(
         norm_mode = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
         quant_mode = QuantMode.use_weight_only(True, True)
         k = act.shape[1]
-        n = weight_scaling_factor_1.shape[-1] // 2
+        n = fc2_prequant_scale.shape[
+            -1]  # get the original n from prequant scale because either weight or scale could be interleaved
         num_experts = weight_scaling_factor_1.shape[0]
 
         with tensorrt_llm.net_guard(network):
@@ -202,18 +206,50 @@ def _woq_moe_groupwise_matmul(self,
         ref_weight_1 += zero_1.repeat_interleave(group_size, dim=1)
         ref_weight_2 += zero_2.repeat_interleave(group_size, dim=1)
         activation_type = torch.float8_e4m3fn if has_alpha else activation_dtype
+        do_weight_interleave = get_sm_version(
+        ) != 90 or not has_alpha  # Hopper w4a8 does not interleave weight
         cuda_q_weight_1 = preprocessor(
-            unprocessed_weight_1.cpu(), quantized_weight_dtype,
-            activation_type).view(activation_dtype).cpu()
+            unprocessed_weight_1.cpu(),
+            quantized_weight_dtype,
+            activation_type,
+            do_weight_interleave=do_weight_interleave).view(
+                activation_dtype).cpu()
         cuda_q_weight_2 = preprocessor(
-            unprocessed_weight_2.cpu(), quantized_weight_dtype,
-            activation_type).view(activation_dtype).cpu()
-        if has_alpha and activation_dtype == torch.bfloat16:
+            unprocessed_weight_2.cpu(),
+            quantized_weight_dtype,
+            activation_type,
+            do_weight_interleave=do_weight_interleave).view(
+                activation_dtype).cpu()
+        if get_sm_version() == 89 and has_alpha:
             scale_1 = scale_1.to(torch.float16).view(activation_dtype)
             scale_2 = scale_2.to(torch.float16).view(activation_dtype)
             zero_1 = zero_1.to(torch.float16).view(activation_dtype)
             zero_2 = zero_2.to(torch.float16).view(activation_dtype)
 
+        if get_sm_version() == 90 and has_alpha:
+            if has_zero:
+                pytest.skip(
+                    "has_zero is not supported in Hopper with WINT4AFP8.")
+
+            def interleave_scales(scales: torch.Tensor, interleave_dim: int):
+                # [num_experts, num_groups, num_cols] --> [num_experts, num_groups // interleave, num_cols * interleave]
+                # Note: num_groups = num_rows // group_size
+                E, G, C = scales.shape
+                I = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
+                    interleave_dim, group_size)
+                assert G % I == 0, f"Group dimension ({G}) must be divisible by interleave factor ({I})."
+                scales_interleaved = scales.reshape(E, G // I, I, C)
+                scales_interleaved = scales_interleaved.permute(0, 1, 3, 2)
+                scales_interleaved = scales_interleaved.reshape(
+                    E, G // I, C * I)
+                return scales_interleaved.contiguous()
+
+            scale_1 = scale_1.to(torch.bfloat16).view(activation_dtype)
+            scale_2 = scale_2.to(torch.bfloat16).view(activation_dtype)
+            scale_1 = interleave_scales(scale_1, k)
+            scale_2 = interleave_scales(scale_2, n)
+            zero_1, zero_2 = None, None
+
         session = self.create_trt_session(
             activation_dtype_str, activation, router, pre_quant_scale_1,
             pre_quant_scale_2, cuda_q_weight_1, cuda_q_weight_2, scale_1,
@@ -278,7 +314,7 @@ def test_moe_w4a16(self, m, n, k, experts, dtype, has_pre_quant, has_zero):
                            (1, 14336, 4096, 8, "bfloat16", True, False),
                            (1, 14336, 4096, 8, "bfloat16", True, True)],
                           name_func=unittest_name_func)
-    @skip_non_ada_unittest
+    @skip_neither_ada_nor_hopper_unittest
     def test_moe_w4a8(self, m, n, k, experts, dtype, has_pre_quant, has_zero):
 
         self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2,
```
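The interleave_scales helper in the test reshapes scales from [num_experts, num_groups, num_cols] to [num_experts, num_groups // factor, num_cols * factor]. A small, device-independent illustration of that transform follows; the factor is hard-coded to 2 here instead of being queried from get_weight_scale_interleave_factor, and all dimensions are illustrative.

```python
import torch

E, G, C, I = 1, 4, 3, 2  # experts, groups, columns, interleave factor (illustrative)
scales = torch.arange(E * G * C, dtype=torch.float32).reshape(E, G, C)
# Same reshape/permute sequence as interleave_scales in the test above:
interleaved = scales.reshape(E, G // I, I, C).permute(0, 1, 3, 2).reshape(E, G // I, C * I)

print(scales[0])
# tensor([[ 0.,  1.,  2.],
#         [ 3.,  4.,  5.],
#         [ 6.,  7.,  8.],
#         [ 9., 10., 11.]])
print(interleaved[0])
# tensor([[ 0.,  3.,  1.,  4.,  2.,  5.],
#         [ 6.,  9.,  7., 10.,  8., 11.]])
# The I consecutive group scales for each column now sit next to each other.
```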
