Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ep, column major moe kernel. #2690

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions lmdeploy/pytorch/backends/cuda/moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import torch

from lmdeploy.pytorch.kernels.cuda import fused_moe
Expand All @@ -10,7 +12,11 @@
class TritonFusedMoEImpl(FusedMoEImpl):
"""triton fused moe implementation."""

def __init__(self, top_k: int, renormalize: bool = False):
def __init__(self,
top_k: int,
num_experts: int,
renormalize: bool = False):
self.num_experts = num_experts
self.top_k = top_k
self.renormalize = renormalize

Expand All @@ -23,23 +29,48 @@ def update_weights(self, gate_up_weights: torch.Tensor,
2).contiguous().transpose(1, 2)
return gate_up_weights, down_weights

def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor):
def support_ep(self):
"""support expert parallelism."""
return True

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
num_experts = self.num_experts
expert_per_rank = (num_experts + world_size - 1) // world_size
first_expert = rank * expert_per_rank
last_expert = min(first_expert + expert_per_rank, num_experts)
return list(range(first_expert, last_expert))

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
expert_offset = 0
num_experts = None
if expert_list is not None and len(expert_list) != self.num_experts:
expert_offset = expert_list[0]
num_experts = self.num_experts
return fused_moe(hidden_states,
gate_up_weights,
down_weights,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk=self.top_k,
expert_offset=expert_offset,
num_experts=num_experts,
renormalize=self.renormalize)


class TritonFusedMoEBuilder(FusedMoEBuilder):
"""triton fused moe builder."""

@staticmethod
def build(top_k: int, renormalize: bool = False):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
return TritonFusedMoEImpl(top_k=top_k, renormalize=renormalize)
return TritonFusedMoEImpl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize)
14 changes: 10 additions & 4 deletions lmdeploy/pytorch/backends/dlinfer/moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import torch

from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
Expand Down Expand Up @@ -38,9 +40,13 @@ def __init__(self, top_k: int, renormalize: bool = False):
self.top_k = top_k
self.renormalize = renormalize

def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor):
def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
gate_up_weights, down_weights)
Expand All @@ -50,6 +56,6 @@ class DlinferFusedMoEBuilder(FusedMoEBuilder):
"""dlinfer fused moe builder."""

@staticmethod
def build(top_k: int, renormalize: bool = False):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize)
21 changes: 17 additions & 4 deletions lmdeploy/pytorch/backends/moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List

import torch

Expand Down Expand Up @@ -31,10 +32,22 @@ def update_weights(self, gate_up_weights: torch.Tensor,
"""update weights."""
return gate_up_weights, down_weights

def support_ep(self):
"""support expert parallelism."""
return False

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
raise NotImplementedError('Not Implemented.')

@abstractmethod
def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor):
def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
raise NotImplementedError

Expand All @@ -44,6 +57,6 @@ class FusedMoEBuilder(ABC):

@staticmethod
@abstractmethod
def build(top_k: int, renormalize: bool = False):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
raise NotImplementedError
79 changes: 27 additions & 52 deletions lmdeploy/pytorch/kernels/cuda/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
import triton.language as tl

from .activation import silu_and_mul
from .triton_utils import get_kernel_meta, wrap_jit_func
from .triton_utils import get_kernel_meta


def get_cuda_autotune_config():
return [
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 1,
},
num_stages=3,
num_warps=8),
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 1,
},
Expand All @@ -43,34 +43,9 @@ def get_cuda_autotune_config():
@triton.autotune(
configs=get_cuda_autotune_config(),
key=['N', 'K', 'M_NP2'],
warmup=10,
rep=25,
)
@wrap_jit_func(type_hint=dict(
A=torch.Tensor,
B=torch.Tensor,
C=torch.Tensor,
SortedIdx=torch.Tensor,
ExpStart=torch.Tensor,
ExpEnd=torch.Tensor,
Weights=torch.Tensor,
N=int,
K=int,
stride_am=int,
stride_ak=int,
stride_be=int,
stride_bn=int,
stride_bk=int,
stride_cm=int,
stride_cn=int,
BLOCK_SIZE_M=torch.int32,
BLOCK_SIZE_N=torch.int32,
BLOCK_SIZE_K=torch.int32,
GROUP_SIZE_M=torch.int32,
ENABLE_WEIGHTS=bool,
top_k=torch.int32,
expert_offset=torch.int32,
reindex_a=bool,
reindex_c=bool,
))
@triton.jit
def fused_moe_kernel(
A,
Expand Down Expand Up @@ -110,16 +85,23 @@ def fused_moe_kernel(
if M <= 0:
return

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

if pid_m * BLOCK_SIZE_M >= M:

if GROUP_SIZE_M == 1:
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
# pid_m = pid // num_pid_n
# pid_n = pid % num_pid_n
else:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
return

offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
Expand Down Expand Up @@ -189,11 +171,11 @@ def fused_moe_kernel_launcher(
if num_tokens is None:
num_tokens = A.size(0)
M_NP2 = triton.next_power_of_2(num_tokens)
M_NP2 = max(32, M_NP2)
M_NP2 = max(64, M_NP2)
E, N, K = B.shape

def _grid_fn(META):
grid = (triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) *
grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
triton.cdiv(N, META['BLOCK_SIZE_N']), E)
return grid

Expand Down Expand Up @@ -229,13 +211,6 @@ def _grid_fn(META):
)


@wrap_jit_func(type_hint=dict(TopkIdx=torch.Tensor,
SortedIdx=torch.Tensor,
ExpStart=torch.Tensor,
ExpEnd=torch.Tensor,
len_sorted_idx=int,
num_experts=torch.int32,
BLOCK=torch.int32))
@triton.jit
def _start_end_kernel(TopkIdx, SortedIdx, ExpStart, ExpEnd,
len_sorted_idx: int, num_experts: tl.constexpr,
Expand Down
Loading
Loading