Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Enable 8-bit weights in Fused Marlin MoE #8032

Merged
merged 30 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0abac6f
Enable 8-bit weights in Fused Marlin MoE
ElizaWszola Aug 30, 2024
fdf69c2
fix rocm
ElizaWszola Aug 30, 2024
4da163b
bad paste
ElizaWszola Aug 30, 2024
21d2337
add test case; fix imports for tests
dsikka Aug 30, 2024
080ab23
Merge branch 'main' into marlin-moe-8-bit
dsikka Aug 30, 2024
638777a
fix to adapt custom_routin_function
dsikka Aug 30, 2024
bd4b84d
Use select_experts to compute top_k tensors in fused moe
ElizaWszola Sep 2, 2024
bef6b53
bring back fused_moe_marlin -> fused_marlin_moe
ElizaWszola Sep 3, 2024
befc52b
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 4, 2024
b45594c
remove large model
dsikka Sep 4, 2024
effd2cd
Cleanup, comments
ElizaWszola Sep 4, 2024
52c3353
fix moe init
ElizaWszola Sep 4, 2024
882fd9c
move larger models to an options larger test
dsikka Sep 4, 2024
973d914
add optional flag
dsikka Sep 4, 2024
72bc899
swap gpu
dsikka Sep 5, 2024
eea2bc3
Temp disable part of moe tests to see what's breaking
ElizaWszola Sep 5, 2024
9c29dc2
Fixes to act_order, make unit tests more robust
ElizaWszola Sep 5, 2024
6d04dcd
try to narrow down cuda error
ElizaWszola Sep 5, 2024
83e7999
Try different subset of test params
ElizaWszola Sep 6, 2024
6a42eaf
.
ElizaWszola Sep 6, 2024
3288842
.
ElizaWszola Sep 6, 2024
61ef4ba
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 10, 2024
667d23e
fix and cleanup after merge
ElizaWszola Sep 10, 2024
b16838e
cleanup
ElizaWszola Sep 10, 2024
e53abb9
validate cache for the kernel code
ElizaWszola Sep 10, 2024
2f82715
cleanup commented out code
ElizaWszola Sep 11, 2024
f97b524
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 13, 2024
771f693
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 13, 2024
aac7c20
fix type conversion
ElizaWszola Sep 13, 2024
9d7caad
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 211 additions & 90 deletions csrc/moe/marlin_moe_ops.cu

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions csrc/moe/marlin_moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

#include <torch/all.h>

#include "core/scalar_type.hpp"

torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
9 changes: 5 additions & 4 deletions csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
Expand Down
230 changes: 226 additions & 4 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@

Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
Expand All @@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
Expand All @@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
Expand Down Expand Up @@ -99,3 +121,203 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * 2),
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
Expand Down
14 changes: 10 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
Expand Down
Loading
Loading