Support moe w8a8 in pytorch engine #2894

Merged
merged 7 commits on Dec 31, 2024
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/base.py
@@ -28,6 +28,7 @@ class OpType(Enum):
LinearW4A16 = auto()
SoftmaxTopK = auto()
FusedMoE = auto()
FusedMoEW8A8 = auto()


class OpsBackend(ABC):
98 changes: 96 additions & 2 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -4,9 +4,13 @@

import torch

from lmdeploy.pytorch.kernels.cuda import fused_moe
from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
per_token_quant_int8
from lmdeploy.pytorch.models.q_modules import QTensor

from ..moe import FusedMoEBuilder, FusedMoEImpl
from ..moe import (FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
FusedMoEW8A8Impl)


class TritonFusedMoEImpl(FusedMoEImpl):
@@ -74,3 +78,93 @@ def build(top_k: int, num_experts: int, renormalize: bool = False):
return TritonFusedMoEImpl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize)


class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):
"""triton fused moe w8a8 implementation."""

def __init__(self,
top_k: int,
num_experts: int,
renormalize: bool = False,
out_dtype: torch.dtype = torch.float16):
self.num_experts = num_experts
self.top_k = top_k
self.renormalize = renormalize
self.out_dtype = out_dtype

def update_weights(self, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
down_scale: torch.Tensor):
gate_up_weights = gate_up_weights.transpose(1,
2).contiguous().transpose(
1, 2)
down_weights = down_weights.transpose(1,
2).contiguous().transpose(1, 2)
return gate_up_weights, down_weights, gate_up_scale, down_scale

def support_ep(self):
"""support expert parallelism."""
return True

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
num_experts = self.num_experts
expert_per_rank = (num_experts + world_size - 1) // world_size
first_expert = rank * expert_per_rank
last_expert = min(first_expert + expert_per_rank, num_experts)
return list(range(first_expert, last_expert))

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
gate_up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None):
"""forward."""

if isinstance(hidden_states, torch.Tensor):
hidden_states = hidden_states.contiguous()
input_quant, input_scale = per_token_quant_int8(
hidden_states, 1e-7)
else:
assert isinstance(hidden_states, QTensor)
input_quant, input_scale = (hidden_states.tensor,
hidden_states.scale)

expert_offset = 0
num_experts = None
if expert_list is not None and len(expert_list) != self.num_experts:
expert_offset = expert_list[0]
num_experts = self.num_experts
return fused_moe_w8a8(input_quant,
input_scale,
gate_up_weights,
gate_up_scale,
down_weights,
down_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk=self.top_k,
out_dtype=self.out_dtype,
expert_offset=expert_offset,
num_experts=num_experts,
renormalize=self.renormalize)


class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):
"""triton fused moe w8a8 builder."""

@staticmethod
def build(top_k: int,
num_experts: int,
renormalize: bool = False,
out_dtype: torch.dtype = torch.float16):
"""build from mlp."""
return TritonFusedMoEW8A8Impl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize,
out_dtype=out_dtype)
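
For reference, the ep_expert_list method above assigns each rank a contiguous slice of experts via ceiling division. Below is a standalone re-statement of that arithmetic with illustrative numbers, not an lmdeploy API call:

def split_experts(num_experts: int, world_size: int, rank: int):
    """Contiguous expert partition, mirroring ep_expert_list above."""
    expert_per_rank = (num_experts + world_size - 1) // world_size
    first_expert = rank * expert_per_rank
    last_expert = min(first_expert + expert_per_rank, num_experts)
    return list(range(first_expert, last_expert))

# e.g. 8 experts over 3 ranks -> [0, 1, 2], [3, 4, 5], [6, 7]
print([split_experts(8, 3, r) for r in range(3)])

Note also that forward above quantizes a plain tensor input per token on the fly, and reuses the scale already carried by a QTensor otherwise.
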
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -66,6 +66,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.FusedMoE:
from .moe import TritonFusedMoEBuilder
return TritonFusedMoEBuilder
elif layer_type == OpType.FusedMoEW8A8:
from .moe import TritonFusedMoEW8A8Builder
return TritonFusedMoEW8A8Builder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
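
With the new OpType branch wired in, model code can resolve the w8a8 MoE builder the same way it resolves the other CUDA ops. A hedged sketch follows; the get_backend entry point is how other layers obtain builders in this codebase but is not part of this diff, and the sizes are illustrative:

import torch

from lmdeploy.pytorch.backends import get_backend  # assumed entry point, not shown in this PR
from lmdeploy.pytorch.backends.base import OpType

# Resolve the w8a8 fused-MoE builder for the active device backend.
builder = get_backend().get_layer_impl_builder(OpType.FusedMoEW8A8)
impl = builder.build(top_k=2, num_experts=8, renormalize=True,
                     out_dtype=torch.float16)
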
10 changes: 7 additions & 3 deletions lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -48,9 +48,13 @@ def build(hidden_size: int, eps: float = 1e-6):
class TritonLinearW8A8Impl(LinearW8A8Impl):
"""triton linear w8a8 implementation."""

def __init__(self, in_features: int, out_features: int):
def __init__(self,
in_features: int,
out_features: int,
out_dtype: torch.dtype = torch.float16):
self.in_features = in_features
self.out_features = out_features
self.out_dtype = out_dtype

def forward(self,
x,
@@ -70,7 +74,7 @@ def forward(self,
weight,
input_scale,
scale,
output_dtype=torch.float16,
output_dtype=self.out_dtype,
bias=bias)

if all_reduce:
@@ -87,4 +91,4 @@ def build(in_features: int,
bias: bool = True,
dtype: torch.dtype = None):
"""build."""
return TritonLinearW8A8Impl(in_features, out_features)
return TritonLinearW8A8Impl(in_features, out_features, dtype)
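
The practical effect of the new out_dtype argument is that the quantized linear no longer forces its output back to float16; a bfloat16 model keeps its activation dtype. A minimal sketch, using illustrative layer sizes:

import torch

from lmdeploy.pytorch.backends.cuda.qmodules import TritonLinearW8A8Impl

impl = TritonLinearW8A8Impl(in_features=4096, out_features=11008,
                            out_dtype=torch.bfloat16)
assert impl.out_dtype == torch.bfloat16  # forwarded as output_dtype in forward()
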
45 changes: 45 additions & 0 deletions lmdeploy/pytorch/backends/moe.py
@@ -60,3 +60,48 @@ class FusedMoEBuilder(ABC):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
raise NotImplementedError


class FusedMoEW8A8Impl(ABC):
"""fused moe w8a8 implementation."""

def update_weights(self, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
down_scale: torch.Tensor):
"""update weights."""
return gate_up_weights, down_weights, gate_up_scale, down_scale

def support_ep(self):
"""support expert parallelism."""
return False

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
raise NotImplementedError('Not Implemented.')

@abstractmethod
def forward(self,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
gate_up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
raise NotImplementedError


class FusedMoEW8A8Builder(ABC):
"""fused moe w8a8 builder."""

@staticmethod
@abstractmethod
def build(top_k: int,
num_experts: int,
renormalize: bool = False,
out_dtype: torch.dtype = torch.float16):
"""build from mlp."""
raise NotImplementedError
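
FusedMoEW8A8Impl and FusedMoEW8A8Builder define the contract other device backends would implement: only forward and build are abstract, while update_weights, support_ep and ep_expert_list ship defaults. A minimal sketch of a hypothetical backend plugging into this interface; the names and internals below are invented for illustration:

from typing import List

import torch

from lmdeploy.pytorch.backends.moe import FusedMoEW8A8Builder, FusedMoEW8A8Impl


class MyFusedMoEW8A8Impl(FusedMoEW8A8Impl):
    """Hypothetical backend: satisfies the interface, does no real work."""

    def __init__(self, top_k: int, num_experts: int):
        self.top_k = top_k
        self.num_experts = num_experts

    def forward(self, hidden_states, input_scale, topk_weights, topk_ids,
                gate_up_weights, gate_up_scale, down_weights, down_scale,
                expert_list: List[int] = None):
        raise NotImplementedError('sketch only; a real backend computes MoE here')


class MyFusedMoEW8A8Builder(FusedMoEW8A8Builder):
    """Hypothetical builder returning the impl above."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              renormalize: bool = False,
              out_dtype: torch.dtype = torch.float16):
        return MyFusedMoEW8A8Impl(top_k, num_experts)
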
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -9,6 +9,7 @@
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
from .w8a8_fused_moe import fused_moe_w8a8
from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant,
per_channel_quant, per_token_quant_int8,
rms_norm_dynamic_quant)
@@ -28,4 +29,5 @@
'rms_norm_dynamic_quant',
'flash_attention_fwd',
'flatten_kv_cache',
'fused_moe_w8a8',
]
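
With this export, the new kernel is reachable from the package namespace used by the backend code above:

from lmdeploy.pytorch.kernels.cuda import fused_moe_w8a8
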
67 changes: 39 additions & 28 deletions lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -91,8 +91,6 @@ def fused_moe_kernel(
if GROUP_SIZE_M == 1:
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
# pid_m = pid // num_pid_n
# pid_n = pid % num_pid_n
else:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
@@ -133,7 +131,7 @@ def fused_moe_kernel(
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b)
accumulator = tl.dot(a, b, acc=accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

@@ -271,6 +269,33 @@ def get_start_end(topk_idx: torch.Tensor, sorted_idx: torch.Tensor,
return exp_start, exp_end


def _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int):
"""get sorted idx."""
flatten_topk_ids = topk_ids.flatten()
sorted_idx = flatten_topk_ids.argsort()

exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
num_experts)
return sorted_idx, exp_start, exp_end


def _renormalize(topk_weights: torch.Tensor, renormalize: bool):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if not topk_weights.is_contiguous():
topk_weights = topk_weights.contiguous()
return topk_weights


def _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device,
zeros: bool):
"""make intermediate."""
if zeros:
return torch.zeros(shape, dtype=dtype, device=device)
else:
return torch.empty(shape, dtype=dtype, device=device)


def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -283,31 +308,17 @@ def fused_moe(hidden_states: torch.Tensor,
"""fused moe."""
M = hidden_states.size(0)
E, N, _ = w1.shape
full_exp = False
if num_experts is None:
num_experts = E
elif num_experts == E:
full_exp = True

def __get_sorted_idx(topk_ids: torch.Tensor):
flatten_topk_ids = topk_ids.flatten()
sorted_idx = flatten_topk_ids.argsort()

exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
num_experts)
return sorted_idx, exp_start, exp_end

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if not topk_weights.is_contiguous():
topk_weights = topk_weights.contiguous()
full_exp = num_experts == E

sorted_idx, exp_start, exp_end = __get_sorted_idx(topk_ids)
topk_weights = _renormalize(topk_weights, renormalize)
sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)

if full_exp:
intermediate_cache1 = hidden_states.new_empty((M, topk, N))
else:
intermediate_cache1 = hidden_states.new_zeros((M, topk, N))
intermediate_cache1 = _make_intermediate((M, topk, N),
dtype=hidden_states.dtype,
device=hidden_states.device,
zeros=not full_exp)
# gate and up
fused_moe_kernel_launcher(
hidden_states,
@@ -331,10 +342,10 @@ def __get_sorted_idx(topk_ids: torch.Tensor):
gate_cache = silu_and_mul(intermediate_cache1)
gate_cache = gate_cache.unflatten(0, unflat_size)

if full_exp:
intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1]))
else:
intermediate_cache2 = hidden_states.new_zeros((M, topk, w2.shape[1]))
intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
zeros=not full_exp)
# down
fused_moe_kernel_launcher(
gate_cache,
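
The refactor in this file pulls the sorting, renormalization and buffer-allocation steps out of fused_moe into module-level helpers (_get_sorted_idx, _renormalize, _make_intermediate) so the new w8a8 path can reuse them. The zeros-vs-empty choice matters under expert parallelism: when a rank owns only part of the experts, token slots routed to other ranks' experts are never written and must start zeroed. A standalone sketch of that choice with illustrative shapes:

import torch

def make_intermediate(shape, dtype, device, zeros):
    """Mirror of _make_intermediate above: zero-fill only when needed."""
    if zeros:
        return torch.zeros(shape, dtype=dtype, device=device)
    return torch.empty(shape, dtype=dtype, device=device)

M, topk, N = 4, 2, 8
full_exp = False  # this rank holds only a subset of the experts
cache = make_intermediate((M, topk, N), torch.float16, 'cpu', zeros=not full_exp)
assert (cache == 0).all()  # untouched slots contribute nothing downstream
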