From 2f3ed3b747f1ad5fc69b643c4370db61a7797129 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 18:25:54 +0800 Subject: [PATCH 01/22] add moe_quant_int quantization method Signed-off-by: Jinzhen Lin --- tests/kernels/test_moe.py | 96 ++++- .../layers/fused_moe/fused_moe.py | 402 +++++++++++++++--- .../layers/quantization/__init__.py | 7 +- .../layers/quantization/moe_quant_int.py | 362 ++++++++++++++++ 4 files changed, 809 insertions(+), 58 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/moe_quant_int.py diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7fa5de198445..bd1fb044c9d0 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -14,10 +14,10 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -55,6 +55,98 @@ def test_fused_moe( rtol=0) +@pytest.mark.parametrize("m", [1, 32, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [64, 128]) +@pytest.mark.parametrize("has_zp", [True, False]) +@pytest.mark.parametrize("weight_bits", [4, 8]) +def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype, group_size: int, has_zp: bool, + weight_bits: int): + print(m, n, k, e, topk, dtype, 
group_size, has_zp, weight_bits) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + if weight_bits == 4: + pack_factor = 2 + if has_zp: + quant_type = scalar_types.uint4 + else: + quant_type = scalar_types.uint4b8 + elif weight_bits == 8: + pack_factor = 1 + if has_zp: + quant_type = scalar_types.uint8 + else: + quant_type = scalar_types.uint8b128 + + w1_ref = w1.clone() + w2_ref = w2.clone() + w1_qweight = torch.empty((e, 2 * n, k // pack_factor), + device="cuda", + dtype=torch.uint8) + w2_qweight = torch.empty((e, k, n // pack_factor), + device="cuda", + dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * n, k // group_size), + device="cuda", + dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), + device="cuda", + dtype=dtype) + w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), + device="cuda", + dtype=torch.uint8) + w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), + device="cuda", + dtype=torch.uint8) + + for i in range(e * 2): + expert_id = i % e + if i // e == 0: + w, w_ref, w_qweight, w_scales, w_qzeros = \ + w1, w1_ref, w1_qweight, w1_scales, w1_qzeros + else: + w, w_ref, w_qweight, w_scales, w_qzeros = \ + w2, w2_ref, w2_qweight, w2_scales, w2_qzeros + weight, qweight, scales, qzeros = quantize_weights( + w[expert_id].T, quant_type, group_size, has_zp, False) + weight = weight.T + qweight = qweight.T.contiguous().to(torch.uint8) + scales = scales.T + qzeros = qzeros.T.contiguous().to(torch.uint8) + if weight_bits == 4: + qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] + qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] + + w_ref[expert_id] = weight + w_qweight[expert_id] = qweight + w_scales[expert_id] = scales + w_qzeros[expert_id] = qzeros + + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + 
topk, + renormalize=False, + use_int4_w8a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros, + w2_zp=w2_qzeros, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 308c1d6ac6db..5327b60fed35 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -19,6 +19,201 @@ logger = init_logger(__name__) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w8a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. 
+ + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w8a16: + b_zp = 8 + if not has_zp and use_int8_w8a16: + b_zp = 128 + elif has_zp and use_int4_w8a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. 
+ + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w8a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(compute_type) + + if has_zp and use_int4_w8a16: + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + (offs_bn[None, :] // 2) * stride_bzn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = ((b_zp >> b_zp_shifter) & 0xF).to(compute_type) + b_zp = b_zp.to(compute_type) + elif has_zp and use_int8_w8a16: + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + offs_bn[None, :] * stride_bzn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(compute_type) + + # We accumulate along the K dimension. + accumulator = tl.dot(a, (b.to(compute_type) - b_zp) * b_scale, + acc=accumulator) + + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w8a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -257,7 +452,8 @@ def moe_align_block_size( dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + expert_ids, num_tokens_post_pad, + num_experts >= 256) return sorted_ids, expert_ids, num_tokens_post_pad @@ -266,6 +462,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -277,6 +474,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w8a16: bool, block_shape: Optional[List[int]] = None) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -292,50 +490,105 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16: + elif use_int8_w8a16 or use_int4_w8a16: assert B_scale is not None + assert block_shape is None or 
block_shape[0] == 0 else: assert A_scale is None assert B_scale is None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], + A.shape[0] * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( + B.shape[1], META['BLOCK_SIZE_N']), ) + + if (use_int8_w8a16 or use_int4_w8a16) and \ + block_shape is not None and block_shape[1] > 0: + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w8a16=use_int4_w8a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale 
is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + else: + ndim = lambda x: 0 if x is None else x.ndim + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) if ndim(A_scale) == 2 else 0, + A_scale.stride(1) if ndim(A_scale) == 2 else 0, + B_scale.stride(0) if ndim(A_scale) >= 2 else 0, + B_scale.stride(2) if ndim(A_scale) == 3 else 0, + B_scale.stride(1) if ndim(A_scale) >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @@ -432,7 +685,7 @@ def try_get_optimal_moe_config( # NOTE: For block-wise quant, # BLOCK_K must be divisible by block_shape[1] # BLOCK_N and BLOCK_M has no requirements - if block_shape is not None: + if block_shape is not None and block_shape[0] != 0: config["BLOCK_SIZE_N"] = block_shape[0] config["BLOCK_SIZE_K"] = block_shape[1] return config @@ -531,12 +784,15 @@ def grouped_topk(hidden_states: torch.Tensor, def get_config_dtype_str(dtype: torch.dtype, + 
use_int4_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w8a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -551,14 +807,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor, topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, - a1_scale, a2_scale, block_shape) + use_fp8_w8a8, use_int8_w8a16, use_int4_w8a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -569,8 +828,11 @@ def inplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: @@ -593,14 +855,18 @@ def outplace_fused_experts( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: 
Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, - False, use_fp8_w8a8, use_int8_w8a16, w1_scale, - w2_scale, a1_scale, a2_scale, block_shape) + False, use_fp8_w8a8, use_int8_w8a16, + use_int4_w8a16, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( @@ -611,8 +877,11 @@ def outplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: @@ -635,8 +904,11 @@ def fused_experts(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None): @@ -644,16 +916,15 @@ def fused_experts(hidden_states: torch.Tensor, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, use_int8_w8a16, - w1_scale, w2_scale, a1_scale, + use_int4_w8a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) return hidden_states else: - return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2, - topk_weights, topk_ids, - use_fp8_w8a8, - use_int8_w8a16, w1_scale, - w2_scale, a1_scale, - a2_scale, block_shape) + return torch.ops.vllm.outplace_fused_experts( + hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, + use_int8_w8a16, 
use_int4_w8a16, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -664,13 +935,21 @@ def fused_experts_impl(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None): # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w8a16: + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -687,6 +966,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, M = min(num_tokens, CHUNK_SIZE) config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w8a16=use_int4_w8a16, dtype=hidden_states.dtype) get_config_func = functools.partial( @@ -755,6 +1035,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -766,6 +1047,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w8a16=use_int4_w8a16, block_shape=block_shape) torch.ops._C.silu_and_mul(intermediate_cache2, @@ -776,6 +1058,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ 
-787,6 +1070,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w8a16=use_int4_w8a16, block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), @@ -808,8 +1092,11 @@ def fused_moe( custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -834,8 +1121,12 @@ def fused_moe( note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w8a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. 
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -873,8 +1164,11 @@ def fused_moe( inplace=inplace, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w8a16=use_int4_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index caeb8b95e02f..6f1fdf83fbdb 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -26,7 +26,8 @@ "experts_int8", "neuron_quant", "ipex", - "quark" + "quark", + "moe_quant_int" ] @@ -58,6 +59,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .neuron_quant import NeuronQuantConfig from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig + from .moe_quant_int import MoeQuantIntConfig method_to_config: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -82,7 +84,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, - "quark": QuarkConfig + "quark": QuarkConfig, + "moe_quant_int": MoeQuantIntConfig, } return method_to_config[quantization] diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py new file mode 100644 index 000000000000..c1b62293c40e --- /dev/null +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -0,0 +1,362 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch + +from vllm.distributed import get_tp_group +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import 
LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod) +from vllm.model_executor.utils import set_weight_attrs + + +class MoeQuantIntConfig(QuantizationConfig): + """Config class for Int8 experts quantization.""" + + def __init__(self, linear_quant_method: str, weight_bits: int, + group_size: int, has_zp: bool, lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any]) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.has_zp = has_zp + self.bit8_pack_factor = 8 // self.weight_bits + self.lm_head_quantized = lm_head_quantized + self.linear_quant_method = linear_quant_method + self.modules_to_not_convert = modules_to_not_convert + + if self.linear_quant_method == "gptq": + self.linear_quant_config = GPTQMarlinConfig.from_config( + full_config) + elif self.linear_quant_method == "awq": + self.linear_quant_config = AWQMarlinConfig.from_config(full_config) + else: + raise ValueError("MoeQuantInt only support gptq now.") + + @classmethod + def get_name(cls) -> str: + return "moe_quant_int" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MoeQuantIntConfig": + linear_quant_method = cls.get_from_keys(config, ["quant_method"]) + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + 
default=False) + if linear_quant_method == "gptq": + has_zp = not cls.get_from_keys(config, ["sym"]) + modules_to_not_convert = [] + elif linear_quant_method == "awq": + has_zp = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys( + config, ["modules_to_not_convert"]) + else: + raise ValueError("moe_quant_int only support gptq and awq.") + + return cls(linear_quant_method, weight_bits, group_size, has_zp, + lm_head_quantized, modules_to_not_convert, config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_moe_quant_int_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None + or user_quant == "moe_quant_int") + if can_convert and is_valid_user_quant: + return cls.get_name() + return None + + @classmethod + def is_moe_quant_int_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + if quant_method == "gptq" and not desc_act and num_bits in [4, 8] and \ + GPTQMarlinConfig.is_gptq_marlin_compatible(quant_config): + return True + if quant_method == "awq" and num_bits == 4 and \ + AWQMarlinConfig.is_awq_marlin_compatible(quant_config): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + elif isinstance(layer, LinearBase): + if self.linear_quant_method == "gptq": + return GPTQMarlinLinearMethod(self.linear_quant_config) + elif self.linear_quant_method == "awq": + return AWQMarlinLinearMethod(self.linear_quant_config) + else: + raise ValueError("moe_quant_int only support gptq and awq.") + elif isinstance(layer, FusedMoE): + return MoeQuantIntMethod(self) + return 
None + + +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class MoeQuantIntMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: MoeQuantIntConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.quant_config = self.quant_config + bit8_pack_factor = self.quant_config.bit8_pack_factor + group_size = self.quant_config.group_size + group_size_div_factor = 1 + + # make intermediate_size and hidden_size diviable by group_size + # we reduce the group size to ensure that + # and we would repeat the loaded_weight later + while intermediate_size % group_size or hidden_size % group_size: + group_size = group_size // 2 + group_size_div_factor *= 2 + assert group_size >= 32 + layer.group_size = group_size + layer.group_size_div_factor = group_size_div_factor + + strategy = FusedMoeWeightScaleSupported.GROUP.value + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": False + }) + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = MoeQuantIntMethod.get_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size // + bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size // + bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qweight", 
w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + w13_scales = torch.nn.Parameter(torch.zeros(num_experts, + 2 * intermediate_size, + hidden_size // group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + intermediate_size // + group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + if self.quant_config.has_zp: + w13_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + intermediate_size // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + if self.quant_config.linear_quant_method == "gptq": + # some param are unused, but we need to init them in order to + # load weights + invalid_param_keys = ["w13_g_idx", "w2_g_idx"] + if not self.quant_config.has_zp: + invalid_param_keys += ["w13_qzeros", "w2_qzeros"] + for key in invalid_param_keys: + param = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int32), + requires_grad=False) + layer.register_parameter(key, param) + set_weight_attrs(param, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = 
"softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + + return fused_experts(x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w8a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size]) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def convert_awq_tensor(tensor, tensor_type): + size0 = tensor.size(0) + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor.view(-1, + 8)[:, + [0, 4, 1, 5, 2, 6, 3, 7]].view(size0, -1) + tensor = tensor.T.contiguous() + if tensor_type == "qweight": + tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] + elif tensor_type == "qzeros": + tensor = tensor[1::2, :] * 16 + tensor[::2, :] + return tensor + + def convert_gptq_int4_qzeros(tensor): + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor + 1 + tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 + return tensor + + def moe_quant_int_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: 
str, shard_id: str, + expert_id: int): + if "g_idx" in weight_name: + return + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return + + device = get_tp_group().device + tp_rank = get_tensor_model_parallel_rank() + loaded_weight = loaded_weight.to(device) + shard_size = layer.intermediate_size_per_partition + + # convert gptq and awq weight to a standard format + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + if "weight" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, + "qweight") + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") + else: + loaded_weight = loaded_weight.T + elif layer.quant_config.linear_quant_method == "gptq": + assert layer.quant_config.weight_bits in [4, 8] + if "weight" in weight_name: + loaded_weight = loaded_weight.T.contiguous().view( + torch.uint8) + elif "zeros" in weight_name: + loaded_weight = loaded_weight.view(torch.uint8) + if layer.quant_config.weight_bits == 4: + loaded_weight = convert_gptq_int4_qzeros( + loaded_weight).T + else: + loaded_weight = loaded_weight.T + 1 + else: + loaded_weight = loaded_weight.T + + # repeat the qzeros/scales to fit new group size + if layer.group_size_div_factor > 1 and "qzeros" in weight_name or "scales" in weight_name: + loaded_weight = loaded_weight.repeat_interleave( + layer.group_size_div_factor, 1) + + if "w13_qzeros" in weight_name: + tensor = loaded_weight.view(layer.tp_size, -1, + loaded_weight.size(1))[tp_rank] + if shard_id == "w1": + param.data[expert_id, :shard_size // 2] = tensor + else: + param.data[expert_id, shard_size // 2:] = tensor + elif "w2_qzeros" in weight_name: + param.data[expert_id] = loaded_weight.view( + loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + else: + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return moe_quant_int_weight_loader From 97f18efcd38709450b0d66bce79096043b648b54 Mon Sep 17 00:00:00 
2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 19:20:50 +0800 Subject: [PATCH 02/22] use tl.float32 to dequantize Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5327b60fed35..c064b91284ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -171,25 +171,25 @@ def fused_moe_kernel_gptq_awq( offs_bn[None, :] * stride_bsn + \ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) - b_scale = b_scale.to(compute_type) + b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w8a16: b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ (offs_bn[None, :] // 2) * stride_bzn + \ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = ((b_zp >> b_zp_shifter) & 0xF).to(compute_type) - b_zp = b_zp.to(compute_type) + b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ offs_bn[None, :] * stride_bzn + \ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = b_zp.to(compute_type) + b_zp = b_zp.to(tl.float32) # We accumulate along the K dimension. - accumulator = tl.dot(a, (b.to(compute_type) - b_zp) * b_scale, - acc=accumulator) + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) # Advance the ptrs to the next K block. 
a_ptrs += BLOCK_SIZE_K * stride_ak From 053045275bf4d04827be66d6fe9ba135ea49ef72 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 19:31:14 +0800 Subject: [PATCH 03/22] fix format error Signed-off-by: Jinzhen Lin --- tests/kernels/test_moe.py | 12 ++++-------- vllm/model_executor/layers/fused_moe/fused_moe.py | 6 ++++-- .../layers/quantization/moe_quant_int.py | 13 +++++++------ 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index bd1fb044c9d0..41aa6ec1601e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -14,6 +14,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -75,16 +77,10 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, if weight_bits == 4: pack_factor = 2 - if has_zp: - quant_type = scalar_types.uint4 - else: - quant_type = scalar_types.uint4b8 + quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8 elif weight_bits == 8: pack_factor = 1 - if has_zp: - quant_type = scalar_types.uint8 - else: - quant_type = scalar_types.uint8b128 + quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128 w1_ref = w1.clone() w2_ref = w2.clone() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c064b91284ac..2e7f5286008e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -174,16 +174,18 @@ def fused_moe_kernel_gptq_awq( b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w8a16: + offs_k_true = 
(offs_k[:, None] + BLOCK_SIZE_K * k) // group_size b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ (offs_bn[None, :] // 2) * stride_bzn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk + offs_k_true * stride_bzk b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = ((b_zp >> b_zp_shifter) & 0xF) b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ offs_bn[None, :] * stride_bzn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bzk + offs_k_true * stride_bzk b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = b_zp.to(tl.float32) diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index c1b62293c40e..cb75d48bba37 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -2,11 +2,11 @@ import torch -from vllm.distributed import get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -90,16 +90,16 @@ def is_moe_quant_int_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") - sym = quant_config.get("sym") desc_act = quant_config.get("desc_act") if quant_method == "gptq" and not desc_act and num_bits in [4, 8] and \ GPTQMarlinConfig.is_gptq_marlin_compatible(quant_config): return True - if quant_method == "awq" and num_bits == 4 and \ + elif quant_method == "awq" and num_bits == 4 and \ AWQMarlinConfig.is_awq_marlin_compatible(quant_config): return True - return False + else: + return False def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -341,7 +341,8 @@ def moe_quant_int_weight_loader(param: torch.nn.Parameter, loaded_weight = loaded_weight.T # repeat the qzeros/scales to fit new group size - if layer.group_size_div_factor > 1 and "qzeros" in weight_name or "scales" in weight_name: + if layer.group_size_div_factor > 1 and \ + "qzeros" in weight_name or "scales" in weight_name: loaded_weight = loaded_weight.repeat_interleave( layer.group_size_div_factor, 1) From fb7bba5135824d54e08b0323277c0223a34e48e5 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 19:39:47 +0800 Subject: [PATCH 04/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 15 +++++++++------ .../layers/quantization/moe_quant_int.py | 7 ++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2e7f5286008e..1f23741f48f9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -136,9 +136,9 @@ def fused_moe_kernel_gptq_awq( offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn if not has_zp and use_int4_w8a16: - b_zp = 8 + b_zp_num = 8 if not has_zp and use_int8_w8a16: - b_zp = 128 + b_zp_num = 128 elif has_zp and use_int4_w8a16: b_zp_shifter = (offs_bn[None, :] % 2) * 4 @@ 
-190,7 +190,10 @@ def fused_moe_kernel_gptq_awq( b_zp = b_zp.to(tl.float32) # We accumulate along the K dimension. - b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) accumulator = tl.dot(a, b, acc=accumulator) # Advance the ptrs to the next K block. @@ -579,9 +582,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, C.stride(2), A_scale.stride(0) if ndim(A_scale) == 2 else 0, A_scale.stride(1) if ndim(A_scale) == 2 else 0, - B_scale.stride(0) if ndim(A_scale) >= 2 else 0, - B_scale.stride(2) if ndim(A_scale) == 3 else 0, - B_scale.stride(1) if ndim(A_scale) >= 2 else 0, + B_scale.stride(0) if ndim(B_scale) >= 2 else 0, + B_scale.stride(2) if ndim(B_scale) == 3 else 0, + B_scale.stride(1) if ndim(B_scale) >= 2 else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index cb75d48bba37..5ebc023cbc00 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -29,8 +29,13 @@ def __init__(self, linear_quant_method: str, weight_bits: int, self.bit8_pack_factor = 8 // self.weight_bits self.lm_head_quantized = lm_head_quantized self.linear_quant_method = linear_quant_method - self.modules_to_not_convert = modules_to_not_convert + if modules_to_not_convert is None: + self.modules_to_not_convert = [] + else: + self.modules_to_not_convert = modules_to_not_convert + + self.linear_quant_config: GPTQMarlinConfig | AWQMarlinConfig if self.linear_quant_method == "gptq": self.linear_quant_config = GPTQMarlinConfig.from_config( full_config) From 4bd2c3111a8d5ef0820fd76b8ccb5b8d5ddbd6ba Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 
19:44:40 +0800 Subject: [PATCH 05/22] fix format error Signed-off-by: Jinzhen Lin --- .../layers/quantization/moe_quant_int.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index 5ebc023cbc00..8a974d5d5929 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -29,6 +29,7 @@ def __init__(self, linear_quant_method: str, weight_bits: int, self.bit8_pack_factor = 8 // self.weight_bits self.lm_head_quantized = lm_head_quantized self.linear_quant_method = linear_quant_method + self.full_config = full_config if modules_to_not_convert is None: self.modules_to_not_convert = [] @@ -97,14 +98,14 @@ def is_moe_quant_int_compatible(cls, quant_config: Dict[str, Any]): num_bits = quant_config.get("bits") desc_act = quant_config.get("desc_act") - if quant_method == "gptq" and not desc_act and num_bits in [4, 8] and \ - GPTQMarlinConfig.is_gptq_marlin_compatible(quant_config): - return True - elif quant_method == "awq" and num_bits == 4 and \ - AWQMarlinConfig.is_awq_marlin_compatible(quant_config): - return True - else: - return False + gptq_compatible = quant_method == "gptq" and \ + not desc_act and num_bits in [4, 8] and \ + GPTQMarlinConfig.is_gptq_marlin_compatible(quant_config) + awq_compatible = quant_method == "awq" and \ + num_bits == 4 and \ + AWQMarlinConfig.is_awq_marlin_compatible(quant_config) + + return gptq_compatible or awq_compatible def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -112,9 +113,11 @@ def get_quant_method(self, layer: torch.nn.Module, return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): if self.linear_quant_method == "gptq": - return GPTQMarlinLinearMethod(self.linear_quant_config) + config = GPTQMarlinConfig.from_config(self.full_config) + 
return GPTQMarlinLinearMethod(config) elif self.linear_quant_method == "awq": - return AWQMarlinLinearMethod(self.linear_quant_config) + config = AWQMarlinConfig.from_config(self.full_config) + return AWQMarlinLinearMethod(config) else: raise ValueError("moe_quant_int only support gptq and awq.") elif isinstance(layer, FusedMoE): From ac8ae24e032e50f4bb65d95bc26f2db7174fda10 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 19:59:16 +0800 Subject: [PATCH 06/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 15 ++++++++++----- .../layers/quantization/moe_quant_int.py | 13 ++----------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1f23741f48f9..4a44f734c362 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -580,11 +580,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - A_scale.stride(0) if ndim(A_scale) == 2 else 0, - A_scale.stride(1) if ndim(A_scale) == 2 else 0, - B_scale.stride(0) if ndim(B_scale) >= 2 else 0, - B_scale.stride(2) if ndim(B_scale) == 3 else 0, - B_scale.stride(1) if ndim(B_scale) >= 2 else 0, + A_scale.stride(0) + if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) + if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) + if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) + if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) + if B_scale is not None and B_scale.ndim >= 2 else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index 
8a974d5d5929..28c6fbba23de 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -5,8 +5,8 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -36,15 +36,6 @@ def __init__(self, linear_quant_method: str, weight_bits: int, else: self.modules_to_not_convert = modules_to_not_convert - self.linear_quant_config: GPTQMarlinConfig | AWQMarlinConfig - if self.linear_quant_method == "gptq": - self.linear_quant_config = GPTQMarlinConfig.from_config( - full_config) - elif self.linear_quant_method == "awq": - self.linear_quant_config = AWQMarlinConfig.from_config(full_config) - else: - raise ValueError("MoeQuantInt only support gptq now.") - @classmethod def get_name(cls) -> str: return "moe_quant_int" From 87e191fe25244d8091ffca13b21bc067d6c79dfd Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 20:05:46 +0800 Subject: [PATCH 07/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 -- vllm/model_executor/layers/quantization/moe_quant_int.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4a44f734c362..4b36262c0ca5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -557,8 +557,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ) else: - ndim = 
lambda x: 0 if x is None else x.ndim - fused_moe_kernel[grid]( A, B, diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index 28c6fbba23de..42c13b6920f6 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -104,11 +104,11 @@ def get_quant_method(self, layer: torch.nn.Module, return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): if self.linear_quant_method == "gptq": - config = GPTQMarlinConfig.from_config(self.full_config) - return GPTQMarlinLinearMethod(config) + gptq_config = GPTQMarlinConfig.from_config(self.full_config) + return GPTQMarlinLinearMethod(gptq_config) elif self.linear_quant_method == "awq": - config = AWQMarlinConfig.from_config(self.full_config) - return AWQMarlinLinearMethod(config) + awq_config = AWQMarlinConfig.from_config(self.full_config) + return AWQMarlinLinearMethod(awq_config) else: raise ValueError("moe_quant_int only support gptq and awq.") elif isinstance(layer, FusedMoE): From 29df4d082d0a66a5dbc02b016de38b70a8c13fb3 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 20:11:05 +0800 Subject: [PATCH 08/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/__init__.py | 2 +- vllm/model_executor/layers/quantization/moe_quant_int.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 6f1fdf83fbdb..07601af48c4a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -56,10 +56,10 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .ipex_quant import IPEXConfig from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config + from .moe_quant_int import MoeQuantIntConfig from 
.neuron_quant import NeuronQuantConfig from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig - from .moe_quant_int import MoeQuantIntConfig method_to_config: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index 42c13b6920f6..3aa99b07631c 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -7,12 +7,12 @@ FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import ( LinearBase, UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig, GPTQMarlinLinearMethod) -from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, AWQMarlinLinearMethod) from vllm.model_executor.utils import set_weight_attrs From 99f23f29c2bc7382f3c6b5e065064a0aecec20fe Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 18 Jan 2025 20:13:26 +0800 Subject: [PATCH 09/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_quant_int.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index 3aa99b07631c..bb6446de6f01 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -5,8 +5,8 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from 
vllm.model_executor.layers.linear import ( - LinearBase, UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinConfig, AWQMarlinLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( From 15ae02b0fc461a2f622a5a1ec86e7829f509f759 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 19 Jan 2025 14:43:26 +0800 Subject: [PATCH 10/22] fix error Signed-off-by: Jinzhen Lin --- tests/kernels/test_moe.py | 13 ++++++++----- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 +-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 41aa6ec1601e..4d9de33ff6a0 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -116,15 +116,18 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, weight = weight.T qweight = qweight.T.contiguous().to(torch.uint8) scales = scales.T - qzeros = qzeros.T.contiguous().to(torch.uint8) + if has_zp: + qzeros = qzeros.T.contiguous().to(torch.uint8) if weight_bits == 4: qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] - qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] + if has_zp: + qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] w_ref[expert_id] = weight w_qweight[expert_id] = qweight w_scales[expert_id] = scales - w_qzeros[expert_id] = qzeros + if has_zp: + w_qzeros[expert_id] = qzeros triton_output = fused_moe(a, w1_qweight, @@ -136,8 +139,8 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, use_int8_w8a16=weight_bits == 8, w1_scale=w1_scales, w2_scale=w2_scales, - w1_zp=w1_qzeros, - w2_zp=w2_qzeros, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) diff --git 
a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4b36262c0ca5..eda3ff96bc87 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -457,8 +457,7 @@ def moe_align_block_size( dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad, - num_experts >= 256) + expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad From ed878d92581c34739b7c5fb137bae8853aac5d2b Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 21 Jan 2025 10:06:25 +0800 Subject: [PATCH 11/22] fix use_int4_w4a16 typo Signed-off-by: Jinzhen Lin --- tests/kernels/test_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 60 +++++++++---------- .../layers/quantization/moe_quant_int.py | 2 +- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 4d9de33ff6a0..fececa538735 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -135,7 +135,7 @@ def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, score, topk, renormalize=False, - use_int4_w8a16=weight_bits == 4, + use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=w1_scales, w2_scale=w2_scales, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index eda3ff96bc87..dbb6c2ce4649 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -64,7 +64,7 @@ def fused_moe_kernel_gptq_awq( top_k: tl.constexpr, compute_type: tl.constexpr, has_zp: tl.constexpr, - use_int4_w8a16: tl.constexpr, + use_int4_w4a16: tl.constexpr, use_int8_w8a16: tl.constexpr): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -127,7 +127,7 @@ def 
fused_moe_kernel_gptq_awq( off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - if use_int4_w8a16: + if use_int4_w4a16: b_ptrs = b_ptr + off_experts * stride_be + \ (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn b_shifter = (offs_k[:, None] % 2) * 4 @@ -135,11 +135,11 @@ def fused_moe_kernel_gptq_awq( b_ptrs = b_ptr + off_experts * stride_be + \ offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - if not has_zp and use_int4_w8a16: + if not has_zp and use_int4_w4a16: b_zp_num = 8 if not has_zp and use_int8_w8a16: b_zp_num = 128 - elif has_zp and use_int4_w8a16: + elif has_zp and use_int4_w4a16: b_zp_shifter = (offs_bn[None, :] % 2) * 4 # ----------------------------------------------------------- @@ -164,7 +164,7 @@ def fused_moe_kernel_gptq_awq( (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0) b = tl.load(b_ptrs) - if use_int4_w8a16: + if use_int4_w4a16: b = (b >> b_shifter) & 0xF b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ @@ -173,7 +173,7 @@ def fused_moe_kernel_gptq_awq( b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) - if has_zp and use_int4_w8a16: + if has_zp and use_int4_w4a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ (offs_bn[None, :] // 2) * stride_bzn + \ @@ -198,7 +198,7 @@ def fused_moe_kernel_gptq_awq( # Advance the ptrs to the next K block. 
a_ptrs += BLOCK_SIZE_K * stride_ak - if use_int4_w8a16: + if use_int4_w4a16: b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk else: b_ptrs += BLOCK_SIZE_K * stride_bk @@ -478,7 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - use_int4_w8a16: bool, + use_int4_w4a16: bool, block_shape: Optional[List[int]] = None) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -494,7 +494,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w8a16: + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None assert block_shape is None or block_shape[0] == 0 else: @@ -512,7 +512,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - if (use_int8_w8a16 or use_int4_w8a16) and \ + if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -550,7 +550,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, top_k=top_k, compute_type=compute_type, has_zp=B_zp is not None, - use_int4_w8a16=use_int4_w8a16, + use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16, **config, ) @@ -791,14 +791,14 @@ def grouped_topk(hidden_states: torch.Tensor, def get_config_dtype_str(dtype: torch.dtype, - use_int4_w8a16: Optional[bool] = False, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" - elif use_int4_w8a16: + elif use_int4_w4a16: return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel 
fails when float32 MoE @@ -814,7 +814,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -823,7 +823,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, use_int4_w8a16, w1_scale, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -835,7 +835,7 @@ def inplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -862,7 +862,7 @@ def outplace_fused_experts( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -872,7 +872,7 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, - use_int4_w8a16, w1_scale, w2_scale, w1_zp, w2_zp, + use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -884,7 +884,7 @@ def outplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, 
w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -911,7 +911,7 @@ def fused_experts(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -923,14 +923,14 @@ def fused_experts(hidden_states: torch.Tensor, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, use_int8_w8a16, - use_int4_w8a16, w1_scale, + use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) return hidden_states else: return torch.ops.vllm.outplace_fused_experts( hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w8a16, w1_scale, w2_scale, w1_zp, w2_zp, + use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -942,7 +942,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -951,7 +951,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None): # Check constraints. 
- if use_int4_w8a16: + if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: @@ -973,7 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, M = min(num_tokens, CHUNK_SIZE) config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w8a16=use_int4_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) get_config_func = functools.partial( @@ -1054,7 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w8a16=use_int4_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1077,7 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w8a16=use_int4_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), @@ -1099,7 +1099,7 @@ def fused_moe( custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, - use_int4_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1131,7 +1131,7 @@ def fused_moe( - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. - - use_int4_w8a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. 
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -1171,7 +1171,7 @@ def fused_moe( inplace=inplace, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w8a16=use_int4_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, w1_zp=w1_zp, diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_quant_int.py index bb6446de6f01..95b9d71773d7 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_quant_int.py @@ -262,7 +262,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_int4_w8a16=weight_bits == 4, + use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, From 28b49ffa0476c30f716e02fe36a4f71737088e40 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 21 Jan 2025 12:53:52 +0800 Subject: [PATCH 12/22] moe_quant_int -> moe_wna16 Signed-off-by: Jinzhen Lin --- tests/kernels/test_moe.py | 6 ++-- .../layers/quantization/__init__.py | 6 ++-- .../{moe_quant_int.py => moe_wna16.py} | 35 +++++++++---------- 3 files changed, 23 insertions(+), 24 deletions(-) rename vllm/model_executor/layers/quantization/{moe_quant_int.py => moe_wna16.py} (93%) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index fececa538735..7aa248ed1475 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -66,9 +66,9 @@ def test_fused_moe( @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) -def test_fused_moe_quant_int(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype, group_size: int, has_zp: bool, - weight_bits: int): +def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype, group_size: int, has_zp: bool, + weight_bits: int): print(m, n, k, e, topk, 
dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 07601af48c4a..7a5dea1034e0 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -27,7 +27,7 @@ "neuron_quant", "ipex", "quark", - "moe_quant_int" + "moe_wna16" ] @@ -56,7 +56,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .ipex_quant import IPEXConfig from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config - from .moe_quant_int import MoeQuantIntConfig + from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig @@ -85,7 +85,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, "quark": QuarkConfig, - "moe_quant_int": MoeQuantIntConfig, + "moe_wna16": MoeWNA16Config, } return method_to_config[quantization] diff --git a/vllm/model_executor/layers/quantization/moe_quant_int.py b/vllm/model_executor/layers/quantization/moe_wna16.py similarity index 93% rename from vllm/model_executor/layers/quantization/moe_quant_int.py rename to vllm/model_executor/layers/quantization/moe_wna16.py index 95b9d71773d7..02f8f520ae95 100644 --- a/vllm/model_executor/layers/quantization/moe_quant_int.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -16,7 +16,7 @@ from vllm.model_executor.utils import set_weight_attrs -class MoeQuantIntConfig(QuantizationConfig): +class MoeWNA16Config(QuantizationConfig): """Config class for Int8 experts quantization.""" def __init__(self, linear_quant_method: str, weight_bits: int, @@ -38,7 +38,7 @@ def __init__(self, linear_quant_method: str, weight_bits: int, 
@classmethod def get_name(cls) -> str: - return "moe_quant_int" + return "moe_wna16" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -53,7 +53,7 @@ def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MoeQuantIntConfig": + def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": linear_quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) @@ -67,7 +67,7 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeQuantIntConfig": modules_to_not_convert = cls.get_from_keys( config, ["modules_to_not_convert"]) else: - raise ValueError("moe_quant_int only support gptq and awq.") + raise ValueError("moe_wna16 only support gptq and awq.") return cls(linear_quant_method, weight_bits, group_size, has_zp, lm_head_quantized, modules_to_not_convert, config) @@ -75,15 +75,14 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeQuantIntConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - can_convert = cls.is_moe_quant_int_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None - or user_quant == "moe_quant_int") + can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "moe_wna16") if can_convert and is_valid_user_quant: return cls.get_name() return None @classmethod - def is_moe_quant_int_compatible(cls, quant_config: Dict[str, Any]): + def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") @@ -110,9 +109,9 @@ def get_quant_method(self, layer: torch.nn.Module, awq_config = AWQMarlinConfig.from_config(self.full_config) return AWQMarlinLinearMethod(awq_config) else: - raise ValueError("moe_quant_int only support gptq and awq.") + raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): - return MoeQuantIntMethod(self) + return MoeWNA16Method(self) return None @@ -120,9 +119,9 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): return any(module_name in prefix for module_name in modules_to_not_convert) -class MoeQuantIntMethod(FusedMoEMethodBase): +class MoeWNA16Method(FusedMoEMethodBase): - def __init__(self, quant_config: MoeQuantIntConfig): + def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -152,7 +151,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, assert 'weight_loader' in extra_weight_attrs weight_loader = extra_weight_attrs['weight_loader'] - wrapped_weight_loader = MoeQuantIntMethod.get_weight_loader( + wrapped_weight_loader = MoeWNA16Method.get_weight_loader( layer, weight_loader) extra_weight_attrs['weight_loader'] = wrapped_weight_loader @@ -300,10 +299,10 @@ def convert_gptq_int4_qzeros(tensor): tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 return tensor - def moe_quant_int_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, shard_id: str, - expert_id: int): + def moe_wna16_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: str, + expert_id: int): if "g_idx" in weight_name: return if not layer.quant_config.has_zp and "qzeros" in weight_name: @@ -359,4 +358,4 @@ def moe_quant_int_weight_loader(param: torch.nn.Parameter, weight_loader(param, loaded_weight, weight_name, 
shard_id, expert_id) - return moe_quant_int_weight_loader + return moe_wna16_weight_loader From 218f31c70d0b7eda92798b3918e32e4eaf2c4e13 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 21 Jan 2025 13:39:39 +0800 Subject: [PATCH 13/22] add comment for gptq/awq weight conversion Signed-off-by: Jinzhen Lin --- .../layers/quantization/moe_wna16.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 02f8f520ae95..296005e5b6eb 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -273,16 +273,39 @@ def apply( def get_weight_loader(layer, weight_loader): def convert_awq_tensor(tensor, tensor_type): + # convert awq qweight/qzeros to a standard format (assume int4) + # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) + # qzeros: (k // group_size, n // pack_factor_bit32) -> + # (n // pack_factor_bit8, k // group_size) + # pack_factor_bit32 = 32 // weight_bits + # pack_factor_bit8 = 8 // weight_bits + + # 0. suppose origin shape (a, b), dtype int32 + # 1. convert to uint8, shape (a, b) -> (a, 4 * b) size0 = tensor.size(0) tensor = tensor.view(torch.uint8) + + # 2. unpack to uint4 (only when weight_bits == 4) + # shape (a, 4 * b) -> (a, 4 * b, 2) shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF - tensor = tensor.view(-1, - 8)[:, - [0, 4, 1, 5, 2, 6, 3, 7]].view(size0, -1) + + # 3. change order, see + # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py + # shape -> (a, 4 * b * pack_factor_bit8) + reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] + tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] + tensor = tensor.view(size0, -1) + + # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) tensor = tensor.T.contiguous() + + # 5. 
repack (only when weight_bits == 4) + # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) + # qzeros shape -> (4 * b, a) + if tensor_type == "qweight": tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] elif tensor_type == "qzeros": @@ -329,6 +352,7 @@ def moe_wna16_weight_loader(param: torch.nn.Parameter, loaded_weight = loaded_weight.T.contiguous().view( torch.uint8) elif "zeros" in weight_name: + # add 1 to gptq qzeros to align with awq loaded_weight = loaded_weight.view(torch.uint8) if layer.quant_config.weight_bits == 4: loaded_weight = convert_gptq_int4_qzeros( From d554e8e7dfe8d588f3a197a55dacab75279ba9db Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 22 Jan 2025 19:52:22 +0800 Subject: [PATCH 14/22] support sm70 (gptq) and sm75 (gptq/awq) Signed-off-by: Jinzhen Lin --- .../layers/quantization/moe_wna16.py | 58 +++++++++++++++---- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 296005e5b6eb..591d0ecc513c 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -7,13 +7,18 @@ FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + AWQLinearMethod) from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinConfig, AWQMarlinLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.gptq import (GPTQConfig, + GPTQLinearMethod) from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig, GPTQMarlinLinearMethod) from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform class 
MoeWNA16Config(QuantizationConfig): @@ -30,6 +35,25 @@ def __init__(self, linear_quant_method: str, weight_bits: int, self.lm_head_quantized = lm_head_quantized self.linear_quant_method = linear_quant_method self.full_config = full_config + self.use_marlin = False + if self.linear_quant_method == "gptq": + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( + full_config) + elif self.linear_quant_method == "awq": + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + if device_capability < awq_min_capability: + raise ValueError( + "The quantization method moe_wna16 + awq is not supported " + "for the current GPU. " + f"Minimum capability: {awq_min_capability}. " + f"Current capability: {device_capability}.") + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( + full_config) + else: + raise ValueError("moe_wna16 only support gptq and awq.") if modules_to_not_convert is None: self.modules_to_not_convert = [] @@ -46,7 +70,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 80 + return 70 @classmethod def get_config_filenames(cls) -> List[str]: @@ -88,12 +112,15 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): num_bits = quant_config.get("bits") desc_act = quant_config.get("desc_act") + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + gptq_compatible = quant_method == "gptq" and \ - not desc_act and num_bits in [4, 8] and \ - GPTQMarlinConfig.is_gptq_marlin_compatible(quant_config) - awq_compatible = quant_method == "awq" and \ - num_bits == 4 and \ - AWQMarlinConfig.is_awq_marlin_compatible(quant_config) + not desc_act and num_bits in [4, 8] + awq_compatible = 
quant_method == "awq" and num_bits == 4 and \ + device_capability >= awq_min_capability return gptq_compatible or awq_compatible @@ -102,12 +129,19 @@ def get_quant_method(self, layer: torch.nn.Module, if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): - if self.linear_quant_method == "gptq": - gptq_config = GPTQMarlinConfig.from_config(self.full_config) - return GPTQMarlinLinearMethod(gptq_config) - elif self.linear_quant_method == "awq": - awq_config = AWQMarlinConfig.from_config(self.full_config) - return AWQMarlinLinearMethod(awq_config) + method_map = { + # key: (quant_method, use_marlin) + # value: (QuantizationConfig, QuantizationLinearMethod) + ("gptq", True): (GPTQMarlinConfig, GPTQMarlinLinearMethod), + ("gptq", False): (GPTQConfig, GPTQLinearMethod), + ("awq", True): (AWQMarlinConfig, AWQMarlinLinearMethod), + ("awq", False): (AWQConfig, AWQLinearMethod) + } + + if (self.linear_quant_method, self.use_marlin) in method_map: + quant_config_cls, quant_method_cls = method_map[( + self.linear_quant_method, self.use_marlin)] + return quant_method_cls(quant_config_cls(self.full_config)) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): From 8d43a53c10a68b14fd3f7cada854eddcd5424778 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 23 Jan 2025 00:38:39 +0800 Subject: [PATCH 15/22] fix bug Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 591d0ecc513c..0ddd9bf19a8a 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -141,7 +141,8 @@ def get_quant_method(self, layer: torch.nn.Module, if (self.linear_quant_method, self.use_marlin) in method_map: 
quant_config_cls, quant_method_cls = method_map[( self.linear_quant_method, self.use_marlin)] - return quant_method_cls(quant_config_cls(self.full_config)) + return quant_method_cls( + quant_config_cls.from_config(self.full_config)) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): From b56f6a0acff0eb2191a1b75148a01a7559323038 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 23 Jan 2025 01:06:10 +0800 Subject: [PATCH 16/22] fix mypy error Signed-off-by: Jinzhen Lin --- .../layers/quantization/moe_wna16.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 0ddd9bf19a8a..75113fe3fa19 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -129,20 +129,20 @@ def get_quant_method(self, layer: torch.nn.Module, if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): - method_map = { - # key: (quant_method, use_marlin) - # value: (QuantizationConfig, QuantizationLinearMethod) - ("gptq", True): (GPTQMarlinConfig, GPTQMarlinLinearMethod), - ("gptq", False): (GPTQConfig, GPTQLinearMethod), - ("awq", True): (AWQMarlinConfig, AWQMarlinLinearMethod), - ("awq", False): (AWQConfig, AWQLinearMethod) - } - - if (self.linear_quant_method, self.use_marlin) in method_map: - quant_config_cls, quant_method_cls = method_map[( - self.linear_quant_method, self.use_marlin)] - return quant_method_cls( - quant_config_cls.from_config(self.full_config)) + if self.linear_quant_method == "gptq": + if self.use_marlin: + return GPTQMarlinLinearMethod( + GPTQMarlinConfig.from_config(self.full_config)) + else: + return GPTQLinearMethod( + GPTQConfig.from_config(self.full_config)) + elif self.linear_quant_method == "awq": + if self.use_marlin: + return 
AWQMarlinLinearMethod( + AWQMarlinConfig.from_config(self.full_config)) + else: + return AWQLinearMethod( + AWQConfig.from_config(self.full_config)) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): From 78d7035b755180dcf8fec1ca7ec3e0d5f13e3040 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 26 Jan 2025 10:16:08 +0800 Subject: [PATCH 17/22] make compatible with main Signed-off-by: Jinzhen Lin --- .../layers/quantization/moe_wna16.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 75113fe3fa19..956ce8cea814 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -160,7 +160,7 @@ def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, + hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): layer.quant_config = self.quant_config @@ -171,7 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # make intermediate_size and hidden_size diviable by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later - while intermediate_size % group_size or hidden_size % group_size: + while intermediate_size_per_partition % group_size or hidden_size % group_size: group_size = group_size // 2 group_size_div_factor *= 2 assert group_size >= 32 @@ -191,38 +191,39 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, extra_weight_attrs['weight_loader'] = wrapped_weight_loader # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size // - bit8_pack_factor, - 
dtype=torch.uint8), + w13_qweight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8), requires_grad=False) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) # down_proj (row parallel) - w2_qweight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size // - bit8_pack_factor, - dtype=torch.uint8), + w2_qweight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8), requires_grad=False) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) - w13_scales = torch.nn.Parameter(torch.zeros(num_experts, - 2 * intermediate_size, - hidden_size // group_size, - dtype=params_dtype), + w13_scales = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - intermediate_size // - group_size, - dtype=params_dtype), + w2_scales = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) @@ -230,7 +231,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, if self.quant_config.has_zp: w13_qzeros = torch.nn.Parameter(torch.zeros( num_experts, - 2 * intermediate_size // bit8_pack_factor, + 2 * intermediate_size_per_partition // bit8_pack_factor, hidden_size // group_size, dtype=torch.uint8), requires_grad=False) @@ -240,7 +241,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, w2_qzeros = 
torch.nn.Parameter(torch.zeros( num_experts, hidden_size // bit8_pack_factor, - intermediate_size // group_size, + intermediate_size_per_partition // group_size, dtype=torch.uint8), requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) From 75121437ab3a4d993d88093759fd906fe6c6d519 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 26 Jan 2025 10:22:45 +0800 Subject: [PATCH 18/22] fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 956ce8cea814..9d760a8f8b2a 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -171,7 +171,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # make intermediate_size and hidden_size diviable by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later - while intermediate_size_per_partition % group_size or hidden_size % group_size: + while intermediate_size_per_partition % group_size or \ + hidden_size % group_size: group_size = group_size // 2 group_size_div_factor *= 2 assert group_size >= 32 From 8dfb2357702dce1f165709b9edee51f060ff8473 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 26 Jan 2025 11:18:13 +0800 Subject: [PATCH 19/22] fix description Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 9d760a8f8b2a..6fd32de34929 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -22,7 +22,7 @@ class MoeWNA16Config(QuantizationConfig): - """Config class for Int8 
experts quantization.""" + """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" def __init__(self, linear_quant_method: str, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool, From 9f944ed5affe637c05da8e6bc4cc112a4939864f Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 26 Jan 2025 14:17:44 +0800 Subject: [PATCH 20/22] add description Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 6fd32de34929..a9e4c399504e 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -155,6 +155,11 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): class MoeWNA16Method(FusedMoEMethodBase): + """Linear method for MOE WNA16 (W8A16/W4A16) quantization. + + Args: + quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. 
+ """ def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config From f125988409d2b7d00ec4581c052a98454e85a2ae Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 29 Jan 2025 18:25:55 +0800 Subject: [PATCH 21/22] fix ci error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index a9e4c399504e..8c8495b1eb45 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -100,8 +100,7 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "moe_wna16") - if can_convert and is_valid_user_quant: + if can_convert and user_quant == "moe_wna16": return cls.get_name() return None @@ -128,7 +127,9 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() - elif isinstance(layer, LinearBase): + elif isinstance(layer, FusedMoE): + return MoeWNA16Method(self) + else: if self.linear_quant_method == "gptq": if self.use_marlin: return GPTQMarlinLinearMethod( @@ -145,9 +146,6 @@ def get_quant_method(self, layer: torch.nn.Module, AWQConfig.from_config(self.full_config)) else: raise ValueError("moe_wna16 only support gptq and awq.") - elif isinstance(layer, FusedMoE): - return MoeWNA16Method(self) - return None def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): From ea3970282137f8dcc04d3265680a3366408494fe Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 29 Jan 2025 18:32:17 +0800 Subject: [PATCH 22/22] 
fix format error Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/moe_wna16.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 8c8495b1eb45..8cd9c0a7ef25 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -5,8 +5,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.awq import (AWQConfig, AWQLinearMethod) from vllm.model_executor.layers.quantization.awq_marlin import (