Add silu mul kernel (#2469)
* add silu kernel, optimize apply rotary kernel

* add ut

* lint

* fix fused moe
grimoire authored Sep 19, 2024
1 parent a513526 commit 97449e3
Showing 7 changed files with 195 additions and 33 deletions.
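
For orientation: the fused op added here takes the gate and up projections concatenated on the last dimension and returns SiLU(gate) * up. A minimal PyTorch reference of that computation (not part of the diff; shapes are illustrative):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    """Split the last dim into gate/up halves, apply SiLU to the gate,
    then multiply elementwise by up; this is what the new Triton kernel fuses."""
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up


x = torch.rand(4, 8192)          # [tokens, 2 * hidden]
y = silu_and_mul_reference(x)    # [4, 4096]
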
47 changes: 47 additions & 0 deletions lmdeploy/pytorch/backends/cuda/activation.py
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn

from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

from ..activation import SiluAndMulBuilder, SiluAndMulImpl


class TritonSiluAndMulImpl(SiluAndMulImpl):
    """silu + multiple residual fused implementation."""

    def __init__(self, inplace: bool):
        self.inplace = inplace

    def _forward_naive(self, x):
        """forward naive."""
        gate, up = x.chunk(2, -1)
        return nn.functional.silu(gate, self.inplace) * up

    def forward(self, x):
        """forward."""

        if x.size(-1) % 2048 != 0:
            return self._forward_naive(x)

        out = None
        x_shape = None
        if x.dim() != 2:
            x_shape = x.shape
            x = x.flatten(0, -2)
        if self.inplace:
            out = x.chunk(2, -1)[0]

        out = silu_and_mul(x, out)

        if x_shape is not None:
            out = out.unflatten(0, x_shape[:-1])
        return out


class TritonSiluAndMulBuilder(SiluAndMulBuilder):
    """silu and mul implementation builder."""

    @staticmethod
    def build(inplace: bool = False):
        """build."""
        return TritonSiluAndMulImpl(inplace)
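
A hedged usage sketch of the builder and impl above (the dispatcher that normally resolves them is omitted; shapes and dtype are illustrative):

import torch

from lmdeploy.pytorch.backends.cuda.activation import TritonSiluAndMulBuilder

# Build the implementation the way the backend dispatcher would.
impl = TritonSiluAndMulBuilder.build(inplace=False)

# Last dim is 2 * hidden. A multiple of 2048 takes the Triton kernel path;
# anything else falls back to _forward_naive.
x = torch.rand(2, 16, 8192, dtype=torch.float16, device='cuda')
out = impl.forward(x)  # shape (2, 16, 4096)
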
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -44,6 +44,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
        elif layer_type == OpType.MultinomialSampling:
            from .multinomial_sampling import TritonMultinomialSamplingBuilder
            return TritonMultinomialSamplingBuilder
+        elif layer_type == OpType.SiluAndMul:
+            from .activation import TritonSiluAndMulBuilder
+            return TritonSiluAndMulBuilder
        elif layer_type == OpType.LinearW4A16:
            from awq.modules.linear.gemm import AWQ_INSTALLED
            if AWQ_INSTALLED:
77 changes: 77 additions & 0 deletions lmdeploy/pytorch/kernels/cuda/activation.py
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from packaging import version

TRITON_VERSION = version.parse(triton.__version__)

if TRITON_VERSION >= version.parse('3.0.0'):

    fast_expf = tl.math.exp
else:
    tanh = tl.math.tanh
    fast_expf = tl.math.fast_expf


@triton.jit
def _silu_and_mul_kernel(
    gateup_ptr,
    out_ptr,
    N: tl.constexpr,
    stride_gum: tl.constexpr,
    stride_gun: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """silu and mul kernel."""
    m_id = tl.program_id(0)

    up_ptr = gateup_ptr + N * stride_gun

    offs_n = tl.arange(0, BLOCK_SIZE_N)
    gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun
    up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun
    out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on

    for _ in range(0, N, BLOCK_SIZE_N):
        gate = tl.load(gate_ptrs).to(tl.float32)
        up = tl.load(up_ptrs).to(tl.float32)

        gate = gate / (1 + fast_expf(-gate))
        out = gate * up

        tl.store(out_ptrs, out)

        gate_ptrs += BLOCK_SIZE_N * stride_gun
        up_ptrs += BLOCK_SIZE_N * stride_gun
        out_ptrs += BLOCK_SIZE_N * stride_on


def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None):
    """silu and mul."""
    assert gate_up.dim() == 2

    M = gate_up.size(0)
    N = gate_up.size(-1) // 2
    if out is None:
        out_shape = (M, N)
        out = gate_up.new_empty(out_shape)

    BLOCK_SIZE_N = min(N, 1024)
    num_warps = 4
    num_stages = 2
    grid = (M, )
    _silu_and_mul_kernel[grid](gate_up,
                               out,
                               N,
                               stride_gum=gate_up.stride(0),
                               stride_gun=gate_up.stride(1),
                               stride_om=out.stride(0),
                               stride_on=out.stride(1),
                               BLOCK_SIZE_N=BLOCK_SIZE_N,
                               num_warps=num_warps,
                               num_stages=num_stages)

    return out
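
A short usage sketch for silu_and_mul above (requires a CUDA device; the second call mirrors the backend's inplace path by reusing the gate half of the input as the output buffer):

import torch

from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

gate_up = torch.rand(256, 4096, dtype=torch.float16, device='cuda')  # [M, 2 * N]

# Allocate a fresh [M, N] output.
out = silu_and_mul(gate_up)

# Write the result into the gate half of the input instead.
out_inplace = silu_and_mul(gate_up, gate_up.chunk(2, -1)[0])
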
54 changes: 26 additions & 28 deletions lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py
@@ -31,7 +31,6 @@
    half_size=torch.int32,
    BLOCK=torch.int32,
    BLOCK_QH=torch.int32,
-    BLOCK_KH=torch.int32,
    BLOCK_N=torch.int32,
))
@triton.jit(do_not_specialize=('seq_len', ))
@@ -58,11 +57,11 @@ def apply_rotary_pos_emb_qk_kernel(
    half_size: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_QH: tl.constexpr,
-    BLOCK_KH: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """apply rotary on key AND query kernel."""
    seq_block_id = tl.program_id(0)
+    head_id = tl.program_id(1)

    pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
    pos_mask = pos_offset < seq_len
@@ -83,44 +82,44 @@ def apply_rotary_pos_emb_qk_kernel(
    sin_l = tl.load(SIN + cs_offset_l).to(q_elem_type)
    sin_h = tl.load(SIN + cs_offset_h).to(q_elem_type)

-    q_ptr = Q + pos_offset * stride_qs
-    qe_ptr = Q_EMB + pos_offset * stride_qes
-    ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd
-    qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd
-    qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed
-    qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed
-    for _ in range(BLOCK_QH):
+    if head_id < BLOCK_QH:
+        q_ptr = Q + pos_offset * stride_qs
+        qe_ptr = Q_EMB + pos_offset * stride_qes
+        ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd
+        qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd
+        qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed
+        qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed
+        ql_ptrs += head_id * stride_qh
+        qh_ptrs += head_id * stride_qh
+        qel_ptrs += head_id * stride_qeh
+        qeh_ptrs += head_id * stride_qeh
+
        q_l = tl.load(ql_ptrs)
        q_h = tl.load(qh_ptrs)
        qe_l = q_l * cos_l - q_h * sin_l
        qe_h = q_h * cos_h + q_l * sin_h

        tl.store(qel_ptrs, qe_l, mask=seq_mask)
        tl.store(qeh_ptrs, qe_h, mask=seq_mask)
-
-        ql_ptrs += stride_qh
-        qh_ptrs += stride_qh
-        qel_ptrs += stride_qeh
-        qeh_ptrs += stride_qeh
-
-    k_ptr = K + pos_offset * stride_ks
-    ke_ptr = K_EMB + pos_offset * stride_kes
-    kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd
-    kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd
-    kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked
-    keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked
-    for _ in range(BLOCK_KH):
+    else:
+        head_id = head_id - BLOCK_QH
+        k_ptr = K + pos_offset * stride_ks
+        ke_ptr = K_EMB + pos_offset * stride_kes
+        kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd
+        kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd
+        kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked
+        keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked
+        kl_ptrs += head_id * stride_kh
+        kh_ptrs += head_id * stride_kh
+        kel_ptrs += head_id * stride_keh
+        keh_ptrs += head_id * stride_keh
        k_l = tl.load(kl_ptrs)
        k_h = tl.load(kh_ptrs)
        ke_l = k_l * cos_l - k_h * sin_l
        ke_h = k_h * cos_h + k_l * sin_h

        tl.store(kel_ptrs, ke_l, mask=seq_mask)
        tl.store(keh_ptrs, ke_h, mask=seq_mask)
-        kl_ptrs += stride_kh
-        kh_ptrs += stride_kh
-        kel_ptrs += stride_keh
-        keh_ptrs += stride_keh


def apply_rotary_pos_emb(q: Tensor,
@@ -162,7 +161,7 @@ def apply_rotary_pos_emb(q: Tensor,
    num_stages = 4

    kernel_meta = get_kernel_meta(q)
-    grid = [triton.cdiv(seq_len, BLOCK)]
+    grid = [triton.cdiv(seq_len, BLOCK), num_heads_q + num_heads_k]
    apply_rotary_pos_emb_qk_kernel[grid](q,
                                         k,
                                         cos,
@@ -185,7 +184,6 @@ def apply_rotary_pos_emb(q: Tensor,
                                         half_size=half_size,
                                         BLOCK=BLOCK,
                                         BLOCK_QH=num_heads_q,
-                                         BLOCK_KH=num_heads_k,
                                         BLOCK_N=BLOCK_N,
                                         num_warps=num_warps,
                                         num_stages=num_stages,
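
The rotary change above replaces the per-head loops inside the kernel with a second grid axis: each program now handles one (sequence block, head) pair, with query heads on the first num_heads_q grid columns and key heads on the rest. A small sketch of that mapping (head counts are hypothetical, not taken from the library):

import triton

seq_len, BLOCK = 1024, 64          # hypothetical sizes
num_heads_q, num_heads_k = 32, 8   # hypothetical head counts

grid = (triton.cdiv(seq_len, BLOCK), num_heads_q + num_heads_k)

for head_id in range(grid[1]):
    if head_id < num_heads_q:
        print(f'grid column {head_id}: rotates query head {head_id}')
    else:
        print(f'grid column {head_id}: rotates key head {head_id - num_heads_q}')
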
11 changes: 9 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl

+from .activation import silu_and_mul
from .triton_utils import get_kernel_meta, wrap_jit_func


@@ -351,8 +352,14 @@ def __get_sorted_idx(topk_ids: torch.Tensor):
    )

    # activate
-    gate_cache, up_cache = intermediate_cache1.chunk(2, -1)
-    gate_cache = F.silu(gate_cache, inplace=True) * up_cache
+    if intermediate_cache1.size(-1) % 2048 == 0:
+        unflat_size = intermediate_cache1.shape[:-1]
+        intermediate_cache1 = intermediate_cache1.flatten(0, -2)
+        gate_cache = silu_and_mul(intermediate_cache1)
+        gate_cache = gate_cache.unflatten(0, unflat_size)
+    else:
+        gate_cache, up_cache = intermediate_cache1.chunk(2, -1)
+        gate_cache = F.silu(gate_cache, inplace=True) * up_cache

    if full_exp:
        intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1]))
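
A hedged sanity check (hypothetical shapes, not part of the commit) that the new fused branch and the old chunk-and-SiLU branch agree for a 3-D intermediate cache:

import torch
import torch.nn.functional as F

from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

cache = torch.rand(64, 2, 4096, dtype=torch.float16, device='cuda')  # [M, topk, 2 * N]

# New path: flatten to 2-D, run the fused kernel, restore the leading dims.
fused = silu_and_mul(cache.flatten(0, -2)).unflatten(0, cache.shape[:-1])

# Old path: split gate/up, apply SiLU, multiply.
gate, up = cache.chunk(2, -1)
reference = F.silu(gate) * up

torch.testing.assert_close(fused, reference, atol=1e-3, rtol=1e-3)
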
7 changes: 4 additions & 3 deletions lmdeploy/pytorch/kernels/cuda/rms_norm.py
@@ -13,7 +13,7 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
    xf = x.to(tl.float32)

    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
-    out = xf / tl.sqrt(var + eps)
+    out = xf * tl.math.rsqrt(var + eps)
    out = (w * out).to(x.dtype)
    return out

@@ -75,12 +75,13 @@ def add_rms_norm_kernel(input, weight, residual, output, out_residual,
    res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)

    new_x = x + res
+    out_res_ptr = out_residual + prog_id * residual_row_stride
+    tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)
+
    out = _compute_rms_norm(new_x, w, eps, N_COLS)

    out_ptr = output + prog_id * input_row_stride
    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
-    out_res_ptr = out_residual + prog_id * residual_row_stride
-    tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)


def rms_norm(hidden_states: Tensor,
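
The rsqrt rewrite above computes the same normalization as the old division by sqrt; a minimal PyTorch check of that identity (eps and shapes are illustrative):

import torch

x = torch.rand(8, 4096)
w = torch.rand(4096)
eps = 1e-6

var = (x * x).mean(dim=-1, keepdim=True)

out_div = w * (x / torch.sqrt(var + eps))      # old formulation
out_rsqrt = w * (x * torch.rsqrt(var + eps))   # new formulation

torch.testing.assert_close(out_div, out_rsqrt)
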
29 changes: 29 additions & 0 deletions tests/pytorch/kernel/test_activation.py
@@ -0,0 +1,29 @@
import pytest
import torch


class TestSiluAndMul:

    @pytest.fixture
    def seqlen(self):
        yield 256

    @pytest.fixture
    def feat_size(self):
        yield 4096

    @pytest.fixture
    def x(self, seqlen, feat_size):
        yield torch.rand(seqlen, feat_size, dtype=torch.float16, device='cuda')

    @pytest.fixture
    def gt(self, x):
        gate, up = x.chunk(2, -1)
        gate = torch.nn.functional.silu(gate)
        yield gate * up

    def test_silu_and_mul(self, x, gt):
        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

        out = silu_and_mul(x)
        torch.testing.assert_close(out, gt)

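A hedged extension (not part of the commit) that could be added to TestSiluAndMul to also exercise the in-place call pattern used by the backend, reusing the fixtures above:

    def test_silu_and_mul_inplace(self, x, gt):
        """Reuse the gate half of the input as the output buffer,
        as the backend's inplace path does."""
        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul

        out = silu_and_mul(x, x.chunk(2, -1)[0])
        torch.testing.assert_close(out, gt)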