3 changes: 3 additions & 0 deletions .gitignore
@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*

# triton jit
.triton

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
375 changes: 375 additions & 0 deletions tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, fields

import pytest
import torch
import torch.nn.functional as F
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
upcast_from_mxfp)
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up


def deshuffle(w: torch.Tensor):
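# Reverse the column interleaving applied by shuffle_weight: gather the
# even-indexed columns, then the odd-indexed ones, and concatenate them.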
first = w[..., ::2]
second = w[..., 1::2]

deshuffled = torch.concat((first, second), dim=-1)
return deshuffled


def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
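# Build gating logits by reinterpreting int16 bit patterns as bfloat16 so
# every token gets E distinct expert scores.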
randbits = [torch.randperm(E) for _ in range(M)]
x_list = [
(-1)**i *
((16384 +
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
for i, bits in enumerate(randbits)
]
exp_data = torch.stack(x_list).to(
device="cuda") # simulating gate_output (M, E)

# create input tensor
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")

w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")

exp_data_tri = exp_data.clone()
x_tri = x.clone()
w1_tri = w1.clone()
w2_tri = w2.clone()

w1_bias_tri = w1_bias.clone()
w2_bias_tri = w2_bias.clone()
w1_bias_tri = w1_bias_tri.to(torch.float32)
w2_bias_tri = w2_bias_tri.to(torch.float32)

dtype_dict = {
"bf16": torch.bfloat16,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2
}

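# Round-trip the activations (and, for non-mx4 cases, the weights) through
# the target dtype so the bf16 reference sees the same quantization error as
# the triton path.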
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
if w_dtype != "mx4":
# simulate quantization in the reference implementation
w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)

# the triton moe kernel uses transposed weight shapes for its matmuls
w1_tri = w1_tri.transpose(-2, -1)
w2_tri = w2_tri.transpose(-2, -1)

# shuffle weights
w1_tri = shuffle_weight(w1_tri)
w1_bias_tri = shuffle_weight(w1_bias_tri)

# quantize the triton-path weights
x_tri = x.to(dtype_dict[a_dtype])
if w_dtype != "mx4":
pytest.skip("NYI")
else: # quantize to mx4
# Careful with the padding here: the activation padding needs to be a
# multiple of 64; the actual engine is not implemented.
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]

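# w2 consumes the SwiGLU output, which halves the width (2N -> N), so its
# bottom padding is half of w1's right padding; its right padding matches
# w1's input-dim (K) padding.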
w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad

x_pad = w1_bottom_pad

w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
mode="constant",
value=0)
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
mode="constant",
value=0)

w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
mode="constant",
value=0)
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
mode="constant",
value=0)

x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)

w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1)
w_scale_layout, w_scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))

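# Downcast to mxfp4 and immediately upcast so the bf16 reference weights
# carry the same quantization error as the quantized triton weights.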
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)

w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)

w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
**w_layout_opts)
w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
w_scale_layout, **w_scale_layout_opts)

w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
**w_layout_opts)
w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
w_scale_layout, **w_scale_layout_opts)

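# The PrecisionConfigs carry the mxfp4 weight scales so the triton matmuls
# can apply them when dequantizing during the GEMM.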
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))

# truncate so the rest of the reference path can run properly
w1 = w1[..., :K, :2 * N]
w2 = w2[..., :N, :K]

w1 = deshuffle(w1)

w1 = w1.transpose(-1, -2).contiguous()
w2 = w2.transpose(-1, -2).contiguous()

return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)


@dataclass
class ModelConfig:
num_hidden_layers: int = 36
num_experts: int = 128
experts_per_token: int = 4
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
head_dim: int = 64
num_attention_heads: int = 64
num_key_value_heads: int = 8
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0


def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
x_glu = x_glu.clamp(max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
if limit is not None:
x_linear = x_linear.clamp(min=-limit, max=limit)
return out_glu * (x_linear + 1)


def oai_moe_forward(
hidden_states: torch.Tensor, # (M, K)
w1: torch.Tensor, # (E, 2N, K)
w1_bias: torch.Tensor, # (E, 2N)
w2: torch.Tensor, # (E, K, N)
w2_bias: torch.Tensor, # (E, K)
gating_output: torch.Tensor, # (M, E)
topk: int):
# model.py 309:330, assuming gating and norm
t = hidden_states
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
expert_indices = experts.indices
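# Softmax over only the selected top-k logits is equivalent to a softmax
# over all logits followed by renormalization, i.e. renormalize=True on the
# triton path.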

# MLP #1
mlp1_weight = w1[expert_indices, ...]
mlp1_bias = w1_bias[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, limit=7)

# MLP #2
mlp2_weight = w2[expert_indices, ...]
mlp2_bias = w2_bias[expert_indices, ...]
t = torch.einsum("beck,bek->bec", mlp2_weight, t)
t += mlp2_bias

# Weighted sum of experts
t = torch.einsum("bec,be->bc", t, expert_weights)

return t


@dataclass
class Case:
a_dtype: str
w_dtype: str


@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
]
],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
M = num_token
E = ModelConfig.num_experts
K = ModelConfig.hidden_size
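# Tensor parallelism shards the intermediate (expert MLP) dimension.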
N = ModelConfig.intermediate_size // tp
topk = ModelConfig.experts_per_token

x, w1, w1_bias, w2, w2_bias, exp_data, \
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
w2_bias_tri, pc1, pc2 = init_compute_data(
M, K, N, E, a_dtype, w_dtype, num_warps=8)

out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,
w1=w1_tri,
w2=w2_tri,
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2)
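# Drop the columns introduced by activation padding.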
out_triton_monolithic = out_triton_monolithic[..., :K]

out_ref = oai_moe_forward(hidden_states=x,
w1=w1,
w1_bias=w1_bias,
w2=w2,
w2_bias=w2_bias,
gating_output=exp_data,
topk=topk)
assert_close(ref=out_ref,
tri=out_triton_monolithic,
maxtol=0.025,
rmstol=0.005)


def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
topk: int, renormalize: bool, w1_bias: torch.Tensor,
w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
w2_precision: PrecisionConfig) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)

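# Compose the batched prepare/finalize stage with the OAI triton experts
# into a single modular MoE kernel.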
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedOAITritonExperts(
None,
max_num_tokens=max_num_tokens,
num_dispatchers=1,
w1_precision=w1_precision,
w2_precision=w2_precision,
),
)

extra_expert_args = {
"w1_bias": w1_bias,
"w2_bias": w2_bias,
}

topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)

return fused_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
extra_expert_args=extra_expert_args,
)


@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
]
],
)
@pytest.mark.parametrize("num_token", [64])
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
M = num_token
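# Expert parallelism shards the expert dimension.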
E = ModelConfig.num_experts // ep
K = ModelConfig.hidden_size
N = ModelConfig.intermediate_size
topk = ModelConfig.experts_per_token

x, w1, w1_bias, w2, w2_bias, exp_data, \
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
w2_bias_tri, pc1, pc2 = init_compute_data(
M, K, N, E, a_dtype, w_dtype, num_warps=4)

out_tri = batched_moe(a=x_tri,
w1=w1_tri,
w2=w2_tri,
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2)
out_tri = out_tri[..., :K]

out_ref = oai_moe_forward(hidden_states=x,
w1=w1,
w1_bias=w1_bias,
w2=w2,
w2_bias=w2_bias,
gating_output=exp_data,
topk=topk)
assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)


def test_unit_shuffle():
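# shuffle_weight interleaves the gate and linear columns of m; the fused
# swiglu from triton_kernels applied to the shuffled weights should then
# match the reference swiglu on the original layout.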
N = ModelConfig.intermediate_size
K = ModelConfig.hidden_size
m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

x = torch.randn(K, dtype=torch.bfloat16, device="cuda")

m_shuffled = shuffle_weight(m)

out_ref = x @ m
out_ref = swiglu(out_ref, limit=1.0)

out = x @ m_shuffled
out = triton_kernels.swiglu.swiglu_torch(
out,
alpha=1.702,
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))

assert_close(ref=out_ref, tri=out)