refactor for multi backends in dlinfer #2619

Merged: 1 commit, Oct 18, 2024
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ascend import AscendOpsBackend # noqa: F401
@@ -2,12 +2,12 @@
import torch
from torch import Tensor

from lmdeploy.pytorch.kernels.ascend import apply_rotary_pos_emb
from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb

from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl


class AscendApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
"""Apply rotary embedding implementation."""

def forward(self,
@@ -26,10 +26,10 @@ def forward(self,
return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)


class AscendApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):
"""Apply rotary embedding implementation builder."""

@staticmethod
def build():
"""build implementation."""
return AscendApplyRotaryEmbImpl()
return DlinferApplyRotaryEmbImpl()
@@ -5,48 +5,19 @@

from lmdeploy.utils import get_logger

from ..base import OpType
from ..default import DefaultOpsBackend
from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class AscendOpsBackend(DefaultOpsBackend):
class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""

@staticmethod
def get_name() -> str:
"""backend name."""
return 'ascend'

@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get ascend layer builder."""
if layer_type == OpType.Attention:
from .attention import AscendAttentionBuilder
return AscendAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import AscendApplyRotaryEmbBuilder
return AscendApplyRotaryEmbBuilder
elif layer_type == OpType.RMSNorm:
from .norm import AscendRMSNormBuilder
return AscendRMSNormBuilder
elif layer_type == OpType.SoftmaxTopK:
from .moe import AscendSoftmaxTopKBuilder
return AscendSoftmaxTopKBuilder
elif layer_type == OpType.FusedMoE:
from .moe import AscendFusedMoEBuilder
return AscendFusedMoEBuilder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
return super().get_layer_impl_builder(layer_type)

@staticmethod
def get_attention_metadata_cls():
from .attention import AscendAttentionMetadata
return AscendAttentionMetadata

@staticmethod
def get_k_block_shape(
block_size: int,
Expand Up @@ -8,7 +8,7 @@


@dataclass
class AscendAttentionMetadata(AttentionMetadata):
class DlinferAttentionMetadata(AttentionMetadata):
kv_start_indices: Optional[Tensor] = None
block_size: int = 64
attention_mask: Sequence[Tensor] = tuple()
@@ -17,8 +17,8 @@ class AscendAttentionMetadata(AttentionMetadata):
max_kv_seq_len: int = 1


class AscendAttentionImpl(AttentionImpl[AscendAttentionMetadata]):
"""ascend attention implementation."""
class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
"""dlinfer attention implementation."""

def __init__(
self,
@@ -44,8 +44,8 @@ def __init__(
**kwargs,
)

from lmdeploy.pytorch.kernels.ascend import (fill_kv_cache,
paged_attention_fwd)
from lmdeploy.pytorch.kernels.dlinfer import (fill_kv_cache,
paged_attention_fwd)
self.fill_kv_cache = fill_kv_cache
self.paged_attention_fwd = paged_attention_fwd

@@ -56,7 +56,7 @@ def forward(
value: Tensor,
k_cache: Tensor,
v_cache: Tensor,
attn_metadata: AscendAttentionMetadata,
attn_metadata: DlinferAttentionMetadata,
k_scales_zeros: Tensor = None,
v_scales_zeros: Tensor = None,
inplace: bool = True,
@@ -108,8 +108,8 @@ def forward(
return attn_output


class AscendAttentionBuilder(AttentionBuilder[AscendAttentionMetadata]):
"""ascend attention builder."""
class DlinferAttentionBuilder(AttentionBuilder[DlinferAttentionMetadata]):
"""dlinfer attention builder."""

@staticmethod
def build(
@@ -122,14 +122,14 @@ def build(
sliding_window: int = None,
logical_softcapping: float = None,
**kwargs,
) -> AscendAttentionImpl:
) -> DlinferAttentionImpl:
"""build."""
return AscendAttentionImpl(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi_scale=alibi_scale,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
**kwargs)
return DlinferAttentionImpl(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi_scale=alibi_scale,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
**kwargs)
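
For reference, a minimal usage sketch of the renamed builder follows. It assumes a dlinfer-enabled environment (DlinferAttentionImpl binds fill_kv_cache and paged_attention_fwd from lmdeploy.pytorch.kernels.dlinfer at construction time); the numeric values and the omitted keyword arguments are illustrative assumptions, not values taken from this PR.

from lmdeploy.pytorch.backends.dlinfer.attention import DlinferAttentionBuilder

# Illustrative configuration: 32 query heads of size 128 sharing 8 KV heads.
# Parameters not passed here are assumed to keep the builder's defaults.
attn_impl = DlinferAttentionBuilder.build(
    num_heads=32,
    head_size=128,
    num_kv_heads=8,
)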
@@ -2,14 +2,14 @@

import torch

from lmdeploy.pytorch.kernels.ascend import moe_gating_topk_softmax
from lmdeploy.pytorch.kernels.dlinfer import moe_gating_topk_softmax

from ..moe import (FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder,
SoftmaxTopKImpl)


class AscendSoftmaxTopKImpl(SoftmaxTopKImpl):
"""ascend softmax topk implementation."""
class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl):
"""dlinfer softmax topk implementation."""

def __init__(self, top_k: int, dim: int = -1):
self.top_k = top_k
@@ -22,17 +22,17 @@ def forward(self, x: torch.Tensor):
torch.int64)


class AscendSoftmaxTopKBuilder(SoftmaxTopKBuilder):
"""ascend softmax topk implementation builder."""
class DlinferSoftmaxTopKBuilder(SoftmaxTopKBuilder):
"""dlinfer softmax topk implementation builder."""

@staticmethod
def build(top_k: int, dim: int = -1):
"""build."""
return AscendSoftmaxTopKImpl(top_k, dim)
return DlinferSoftmaxTopKImpl(top_k, dim)


class AscendFusedMoEImpl(FusedMoEImpl):
"""ascend fused moe implementation."""
class DlinferFusedMoEImpl(FusedMoEImpl):
"""dlinfer fused moe implementation."""

def __init__(self, top_k: int, renormalize: bool = False):
self.top_k = top_k
@@ -68,10 +68,10 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
return moe_output


class AscendFusedMoEBuilder(FusedMoEBuilder):
"""ascend fused moe builder."""
class DlinferFusedMoEBuilder(FusedMoEBuilder):
"""dlinfer fused moe builder."""

@staticmethod
def build(top_k: int, renormalize: bool = False):
"""build from mlp."""
return AscendFusedMoEImpl(top_k=top_k, renormalize=renormalize)
return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize)
@@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.pytorch.kernels.ascend import rms_norm
from lmdeploy.pytorch.kernels.dlinfer import rms_norm

from ..norm import RMSNormBuilder, RMSNormImpl


class AscendRMSNormImpl(RMSNormImpl):
"""ascend RMS norm implementation."""
class DlinferRMSNormImpl(RMSNormImpl):
"""dlinfer RMS norm implementation."""

def __init__(self, hidden_size: int, eps: float = 1e-6):
self.hidden_size = hidden_size
@@ -26,10 +26,10 @@ def forward(self,
return x, residual


class AscendRMSNormBuilder(RMSNormBuilder):
"""ascend RMS norm implementation builder."""
class DlinferRMSNormBuilder(RMSNormBuilder):
"""dlinfer RMS norm implementation builder."""

@staticmethod
def build(weight: torch.Tensor, eps: float = 1e-6):
"""build."""
return AscendRMSNormImpl(weight, eps)
return DlinferRMSNormImpl(weight, eps)
79 changes: 79 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.utils import get_logger

from ..base import OpType
from ..default import DefaultOpsBackend

logger = get_logger('lmdeploy')


class DlinferOpsBackend(DefaultOpsBackend):
    """dlinfer layer backend."""

    @staticmethod
    def get_name() -> str:
        """backend name."""
        return 'dlinfer'

    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """get dlinfer layer builder."""
        if layer_type == OpType.Attention:
            from .attention import DlinferAttentionBuilder
            return DlinferAttentionBuilder
        elif layer_type == OpType.ApplyRotaryEmb:
            from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
            return DlinferApplyRotaryEmbBuilder
        elif layer_type == OpType.RMSNorm:
            from .norm import DlinferRMSNormBuilder
            return DlinferRMSNormBuilder
        elif layer_type == OpType.SoftmaxTopK:
            from .moe import DlinferSoftmaxTopKBuilder
            return DlinferSoftmaxTopKBuilder
        elif layer_type == OpType.FusedMoE:
            from .moe import DlinferFusedMoEBuilder
            return DlinferFusedMoEBuilder
        else:
            logger.debug(
                f'Op {layer_type} fallback to default implementation.')
            return super().get_layer_impl_builder(layer_type)

    @staticmethod
    def get_attention_metadata_cls():
        from .attention import DlinferAttentionMetadata
        return DlinferAttentionMetadata

    @staticmethod
    def get_k_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        return (
            block_size,
            num_heads,
            head_size,
        )

    @staticmethod
    def get_v_block_shape(
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ) -> Tuple[int, ...]:
        return (
            block_size,
            num_heads,
            head_size,
        )

    @classmethod
    def update_step_context(cls, step_context):
        """update step context."""
        raise NotImplementedError
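
With the shared logic lifted into DlinferOpsBackend, a vendor backend only needs to supply its device-specific pieces. A minimal sketch of that pattern, modeled on the AscendOpsBackend changes above (MyDeviceOpsBackend, the 'mydevice' name, and the method bodies are hypothetical):

from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend


class MyDeviceOpsBackend(DlinferOpsBackend):
    """hypothetical vendor backend reusing the shared dlinfer builders."""

    @staticmethod
    def get_name() -> str:
        """backend name."""
        return 'mydevice'

    @classmethod
    def update_step_context(cls, step_context):
        """fill device-specific metadata (e.g. kv_start_indices) here."""
        return step_context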
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/selector.py
@@ -13,7 +13,7 @@ def get_backend():
from .cuda import CudaOpsBackend
return CudaOpsBackend
if device_type == 'ascend':
from .ascend import AscendOpsBackend
from .dlinfer import AscendOpsBackend
return AscendOpsBackend
else:
raise RuntimeError(f'Unsupported device type: {device_type}')
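
Wiring an additional dlinfer device into the selector would follow the same branch structure. A sketch only: the 'mydevice' string and MyDeviceOpsBackend are hypothetical, the relative imports assume the code sits next to selector.py, and device_type is passed explicitly here rather than read the way get_backend() above obtains it.

def get_backend_sketch(device_type: str):
    """Dispatch mirroring get_backend() above; for illustration only."""
    if device_type == 'cuda':
        from .cuda import CudaOpsBackend
        return CudaOpsBackend
    if device_type == 'ascend':
        from .dlinfer import AscendOpsBackend
        return AscendOpsBackend
    if device_type == 'mydevice':
        from .dlinfer import MyDeviceOpsBackend
        return MyDeviceOpsBackend
    raise RuntimeError(f'Unsupported device type: {device_type}')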