add linear op on dlinfer platform (#2627)
* add linear op on ascend platform

* update code
yao-fengchen authored Nov 5, 2024
1 parent ed9aa15 commit 364a142
Showing 4 changed files with 49 additions and 0 deletions.
32 changes: 32 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

from lmdeploy.pytorch.kernels.dlinfer import linear

from ..linear import LinearBuilder, LinearImpl


class DlinferLinearImpl(LinearImpl):
    """Dlinfer linear implementation api."""

    def forward(self,
                x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        """forward."""
        return linear(x, weight, bias, all_reduce)


class DlinferLinearBuilder(LinearBuilder):
    """Dlinfer linear implementation builder."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: Optional[torch.dtype] = None):
        """build."""
        return DlinferLinearImpl()
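
For context on the builder/impl split: the builder is the factory the op backend queries, and the returned impl is stateless, so in_features/out_features are accepted for interface compatibility but unused. A minimal usage sketch (the shapes are illustrative assumptions; running it requires the dlinfer package and a supported device):

import torch

from lmdeploy.pytorch.backends.dlinfer.linear import DlinferLinearBuilder

# build() ignores the feature sizes and returns a stateless impl.
impl = DlinferLinearBuilder.build(in_features=4096, out_features=4096, bias=False)

x = torch.randn(1, 16, 4096)       # (batch, seq_len, in_features)
weight = torch.randn(4096, 4096)   # (out_features, in_features)
y = impl.forward(x, weight, bias=None, all_reduce=False)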
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -40,6 +40,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
        elif layer_type == OpType.FusedMoE:
            from .moe import DlinferFusedMoEBuilder
            return DlinferFusedMoEBuilder
        elif layer_type == OpType.Linear:
            from .linear import DlinferLinearBuilder
            return DlinferLinearBuilder
        elif layer_type == OpType.LinearW4A16:
            from .awq_modules import AwqLinearW4A16Builder
            return AwqLinearW4A16Builder
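
With the new branch, a dlinfer-based backend resolves OpType.Linear to the dlinfer builder instead of falling through to the default. A dispatch sketch (the backend class name DlinferOpsBackend and the OpType import path are assumptions inferred from the file paths above, not shown in this diff):

from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend

# The backend returns the builder class; build() then yields the impl.
builder = DlinferOpsBackend.get_layer_impl_builder(OpType.Linear)
impl = builder.build(in_features=4096, out_features=4096, bias=False)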
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -4,6 +4,7 @@
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -15,6 +16,7 @@
    'fill_kv_cache',
    'fused_moe',
    'paged_attention_fwd',
    'linear',
    'moe_gating_topk_softmax',
    'multinomial_sampling',
]
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/linear.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
from torch import Tensor


def linear(x: Tensor,
           weight: Tensor,
           bias: Optional[Tensor] = None,
           all_reduce: bool = False):
    return ext_ops.linear(x, weight, bias=bias, all_reduce=all_reduce)
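
The wrapper delegates entirely to the dlinfer extension op; semantically it matches a standard linear layer plus an optional all-reduce over tensor-parallel partial sums. A plain-PyTorch reference of the same contract (illustrative only, not the committed implementation):

from typing import Optional

import torch
import torch.distributed as dist
from torch import Tensor


def linear_reference(x: Tensor,
                     weight: Tensor,
                     bias: Optional[Tensor] = None,
                     all_reduce: bool = False) -> Tensor:
    # y = x @ weight.T + bias, same contract as torch.nn.functional.linear.
    out = torch.nn.functional.linear(x, weight, bias)
    if all_reduce:
        # Sum partial outputs across tensor-parallel ranks
        # (requires torch.distributed to be initialized).
        dist.all_reduce(out)
    return out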
