Skip to content

Commit

Permalink
[Kernel] Enable 8-bit weights in Fused Marlin MoE (vllm-project#8032)
Browse files Browse the repository at this point in the history
Co-authored-by: Dipika <[email protected]>
  • Loading branch information
2 people authored and Jeffwan committed Sep 19, 2024
1 parent 465580d commit 4678370
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 185 deletions.
537 changes: 389 additions & 148 deletions csrc/moe/marlin_moe_ops.cu

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions csrc/moe/marlin_moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

#include <torch/all.h>

#include "core/scalar_type.hpp"

torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
8 changes: 5 additions & 3 deletions csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");
"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
Expand Down
18 changes: 12 additions & 6 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
m: int,
n: int,
Expand All @@ -148,6 +149,7 @@ def test_fused_marlin_moe(
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)

Expand All @@ -161,13 +163,12 @@ def test_fused_marlin_moe(
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
Expand Down Expand Up @@ -240,6 +241,7 @@ def test_fused_marlin_moe(
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2
Expand All @@ -254,14 +256,16 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
if topk > e:
return
Expand All @@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
if group_size == k:
return

quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
Expand Down Expand Up @@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
g_idx,
sort_indices,
topk,
renormalize=False)
renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
3 changes: 2 additions & 1 deletion tests/weight_loading/models-large.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
Empty file modified tests/weight_loading/run_model_weight_loading_test.sh
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * 2),
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
Expand Down
44 changes: 30 additions & 14 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types


def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Expand All @@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand All @@ -48,10 +52,11 @@ def single_marlin_moe(
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]

M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2
N = w.shape[2] // (num_bits // 2)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
Expand All @@ -76,10 +81,13 @@ def single_marlin_moe(
device="cuda",
requires_grad=False)

scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
False)
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
block_size_m, True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

Expand All @@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand All @@ -122,6 +131,7 @@ def fused_marlin_moe(
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
Expand All @@ -131,13 +141,14 @@ def fused_marlin_moe(
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]

M, K = hidden_states.shape
E = w1.shape[0]
Expand Down Expand Up @@ -165,6 +176,9 @@ def fused_marlin_moe(
device="cuda",
requires_grad=False)

scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
Expand All @@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1,
perm1,
workspace,
scalar_type,
M,
2 * N,
K,
Expand All @@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2,
perm2,
workspace,
scalar_type,
M,
K,
N,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights, topk_ids.to(torch.int32)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def get_config_dtype_str(dtype: torch.dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
Expand Down Expand Up @@ -38,10 +40,11 @@ def __init__(

if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits == 4):
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for 4 bits")
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
Expand Down Expand Up @@ -292,4 +295,5 @@ def apply(
topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
num_bits=self.num_bits,
)
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,4 +611,5 @@ def apply(
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype)
8 changes: 1 addition & 7 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]

if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
Expand Down

0 comments on commit 4678370

Please sign in to comment.