-
Notifications
You must be signed in to change notification settings - Fork 492
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add silu kernel, optimize apply rotary kernek * add ut * lint * fix fused moe
- Loading branch information
Showing
7 changed files
with
195 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from torch import nn | ||
|
||
from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul | ||
|
||
from ..activation import SiluAndMulBuilder, SiluAndMulImpl | ||
|
||
|
||
class TritonSiluAndMulImpl(SiluAndMulImpl): | ||
"""silu + multiple residual fused implementation.""" | ||
|
||
def __init__(self, inplace: bool): | ||
self.inplace = inplace | ||
|
||
def _forward_naive(self, x): | ||
"""forward naive.""" | ||
gate, up = x.chunk(2, -1) | ||
return nn.functional.silu(gate, self.inplace) * up | ||
|
||
def forward(self, x): | ||
"""forward.""" | ||
|
||
if x.size(-1) % 2048 != 0: | ||
return self._forward_naive(x) | ||
|
||
out = None | ||
x_shape = None | ||
if x.dim() != 2: | ||
x_shape = x.shape | ||
x = x.flatten(0, -2) | ||
if self.inplace: | ||
out = x.chunk(2, -1)[0] | ||
|
||
out = silu_and_mul(x, out) | ||
|
||
if x_shape is not None: | ||
out = out.unflatten(0, x_shape[:-1]) | ||
return out | ||
|
||
|
||
class TritonSiluAndMulBuilder(SiluAndMulBuilder): | ||
"""silu and mul implementation builder.""" | ||
|
||
@staticmethod | ||
def build(inplace: bool = False): | ||
"""build.""" | ||
return TritonSiluAndMulImpl(inplace) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
from packaging import version | ||
|
||
TRITON_VERSION = version.parse(triton.__version__) | ||
|
||
if TRITON_VERSION >= version.parse('3.0.0'): | ||
|
||
fast_expf = tl.math.exp | ||
else: | ||
tanh = tl.math.tanh | ||
fast_expf = tl.math.fast_expf | ||
|
||
|
||
@triton.jit | ||
def _silu_and_mul_kernel( | ||
gateup_ptr, | ||
out_ptr, | ||
N: tl.constexpr, | ||
stride_gum: tl.constexpr, | ||
stride_gun: tl.constexpr, | ||
stride_om: tl.constexpr, | ||
stride_on: tl.constexpr, | ||
BLOCK_SIZE_N: tl.constexpr, | ||
): | ||
"""silu and mul kernel.""" | ||
m_id = tl.program_id(0) | ||
|
||
up_ptr = gateup_ptr + N * stride_gun | ||
|
||
offs_n = tl.arange(0, BLOCK_SIZE_N) | ||
gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun | ||
up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun | ||
out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on | ||
|
||
for _ in range(0, N, BLOCK_SIZE_N): | ||
gate = tl.load(gate_ptrs).to(tl.float32) | ||
up = tl.load(up_ptrs).to(tl.float32) | ||
|
||
gate = gate / (1 + fast_expf(-gate)) | ||
out = gate * up | ||
|
||
tl.store(out_ptrs, out) | ||
|
||
gate_ptrs += BLOCK_SIZE_N * stride_gun | ||
up_ptrs += BLOCK_SIZE_N * stride_gun | ||
out_ptrs += BLOCK_SIZE_N * stride_on | ||
|
||
|
||
def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): | ||
"""silu and mul.""" | ||
assert gate_up.dim() == 2 | ||
|
||
M = gate_up.size(0) | ||
N = gate_up.size(-1) // 2 | ||
if out is None: | ||
out_shape = (M, N) | ||
out = gate_up.new_empty(out_shape) | ||
|
||
BLOCK_SIZE_N = min(N, 1024) | ||
num_warps = 4 | ||
num_stages = 2 | ||
grid = (M, ) | ||
_silu_and_mul_kernel[grid](gate_up, | ||
out, | ||
N, | ||
stride_gum=gate_up.stride(0), | ||
stride_gun=gate_up.stride(1), | ||
stride_om=out.stride(0), | ||
stride_on=out.stride(1), | ||
BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
num_warps=num_warps, | ||
num_stages=num_stages) | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
import torch | ||
|
||
|
||
class TestSiluAndMul: | ||
|
||
@pytest.fixture | ||
def seqlen(self): | ||
yield 256 | ||
|
||
@pytest.fixture | ||
def feat_size(self): | ||
yield 4096 | ||
|
||
@pytest.fixture | ||
def x(self, seqlen, feat_size): | ||
yield torch.rand(seqlen, feat_size, dtype=torch.float16, device='cuda') | ||
|
||
@pytest.fixture | ||
def gt(self, x): | ||
gate, up = x.chunk(2, -1) | ||
gate = torch.nn.functional.silu(gate) | ||
yield gate * up | ||
|
||
def test_silu_and_mul(self, x, gt): | ||
from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul | ||
|
||
out = silu_and_mul(x) | ||
torch.testing.assert_close(out, gt) |