diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 1f14ac92b9..7675956ab9 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -30,7 +30,7 @@ def check_env_deeplink(device_type: str):
     if device_type in deeplink_device_type_list:
         logger = get_logger('lmdeploy')
         try:
-            import deeplink_ext  # noqa: F401
+            import dlinfer.framework.lmdeploy_ext  # noqa: F401
         except Exception as e:
             _handle_exception(e, 'PyTorch', logger)
 
diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/engine/devices/ascend.py
index a09fa5f655..9c782a3f3d 100644
--- a/lmdeploy/pytorch/engine/devices/ascend.py
+++ b/lmdeploy/pytorch/engine/devices/ascend.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils
 
 
-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):
 
     device = 'ascend'
 
@@ -17,7 +17,8 @@ def update_step_context(cls, step_context):
             single_attention_mask = torch.logical_not(
                 torch.tril(
                     torch.ones(step_context.q_seq_length[i],
-                               step_context.kv_seq_length[i],
+                               step_context.block_offsets.shape[1] *
+                               block_size,
                                dtype=torch.bool).cuda(),
                     diagonal=step_context.kv_seq_length[i] -
                     step_context.q_seq_length[i],
@@ -28,7 +29,7 @@ def update_step_context(cls, step_context):
             block_loc = step_context.block_offsets[i][block_idx]
             token_loc = history_length % block_size
             for _ in range(step_context.q_seq_length[i]):
-                kv_start_indices.append(block_loc * block_size + token_loc)
+                kv_start_indices.append([block_loc * block_size + token_loc])
                 if _ == step_context.q_seq_length[i] - 1:
                     break
                 token_loc = (token_loc + 1) % block_size
@@ -38,4 +39,11 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        setattr(step_context, 'q_start_loc', step_context.q_start_loc.cpu())
+        setattr(step_context, 'q_seq_length', step_context.q_seq_length.cpu())
+        setattr(step_context, 'kv_seq_length',
+                step_context.kv_seq_length.cpu())
+        is_unpaged_prefill = (not step_context.is_decoding) and all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
         return step_context
diff --git a/lmdeploy/pytorch/engine/devices/dipu.py b/lmdeploy/pytorch/engine/devices/dipu.py
deleted file mode 100644
index d2cc9c4243..0000000000
--- a/lmdeploy/pytorch/engine/devices/dipu.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .base_device_utils import BaseDeviceUtils
-
-
-class DIPUDeviceUtils(BaseDeviceUtils):
-
-    device = 'dipu'
-
-    @classmethod
-    def update_step_context(cls, step_context):
-        """update step context."""
-        raise NotImplementedError('`update_step_context` of '
-                                  f'<{cls}> not implemented.')
diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py
index bd207a1ecb..8ab92e0158 100644
--- a/lmdeploy/pytorch/kernels/ascend/__init__.py
+++ b/lmdeploy/pytorch/kernels/ascend/__init__.py
@@ -1,6 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from ..dipu import (apply_rotary_pos_emb, fill_kv_cache, fused_rotary_emb,
-                    multinomial_sampling, paged_attention_fwd, rms_norm)
+from ..default import multinomial_sampling
+from .apply_rotary_pos_emb import apply_rotary_pos_emb
+from .fill_kv_cache import fill_kv_cache
+from .fused_rotary_emb import fused_rotary_emb
+from .moe_gating_topk_softmax import moe_gating_topk_softmax
+from .pagedattention import paged_attention_fwd
+from .rms_norm import rms_norm
 
 __all__ = [
     'rms_norm',
@@ -8,5 +13,6 @@
     'fused_rotary_emb',
     'fill_kv_cache',
     'paged_attention_fwd',
+    'moe_gating_topk_softmax',
     'multinomial_sampling',
 ]
diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
new file mode 100644
index 0000000000..4a4039c44d
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from torch import Tensor
+
+
+def apply_rotary_pos_emb(
+    query_states: Tensor,
+    key_states: Tensor,
+    cos: Tensor,
+    sin: Tensor,
+    position_ids: Tensor,
+    position_ids_1d: Tensor,
+    q_embed=None,
+    k_embed=None,
+    context=None,
+):
+    bs, head, dim = query_states.shape
+    num_kv_heads = key_states.shape[1]
+    query_states_reshaped = query_states.reshape(1, bs, head, dim)
+    key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
+    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
+        if len(cos.shape) == 3 and len(sin.shape) == 3:
+            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
+        elif len(cos.shape) == 2 and len(sin.shape) == 2:
+            cos = cos[position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        else:
+            raise RuntimeError('Cannot handle cos/sin shape dims!')
+
+        if context:
+            setattr(context, 'cos', cos)
+            setattr(context, 'sin', sin)
+    cached_cos = context.cos if context else cos
+    cached_sin = context.sin if context else sin
+    query_states, key_states = ext_ops.apply_rotary_pos_emb(
+        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
+        None, None)
+    query_states = query_states.view(bs, head, dim)
+    key_states = key_states.view(bs, num_kv_heads, dim)
+    if q_embed is None:
+        q_embed = query_states
+    else:
+        q_embed.copy_(query_states)
+    if k_embed is None:
+        k_embed = key_states
+    else:
+        k_embed.copy_(key_states)
+    return q_embed, k_embed
diff --git a/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py
new file mode 100644
index 0000000000..333e500532
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from torch import Tensor
+
+
+def fill_kv_cache(
+    key_states: Tensor,
+    value_states: Tensor,
+    key_caches: Tensor,
+    value_caches: Tensor,
+    q_start_loc: Tensor,
+    q_seq_length: Tensor,
+    kv_seq_length: Tensor,
+    max_q_seq_length: int,
+    block_offsets: Tensor,
+    context: None,
+):
+    """fill key/value state to cache for paged attention."""
+    ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches,
+                          context.kv_start_indices)
diff --git a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
new file mode 100644
index 0000000000..03fa2910af
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def fused_rotary_emb(
+    query_states: Tensor,
+    key_states: Tensor,
+    position_ids: torch.LongTensor,
+    inv_freq: Tensor,
+    scaling_factor: float,
+    out_q: Tensor = None,
+    out_k: Tensor = None,
+    context=None,
+):
+    batch, seqlen, head, dim = query_states.shape
+    num_kv_heads = key_states.shape[-2]
+    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
+    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
+    position_ids = position_ids.squeeze(0).unsqueeze(-1)
+    pos_freq = position_ids / scaling_factor * inv_freq
+    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
+        cos = (torch.cos(pos_freq).view(batch, seqlen, 1,
+                                        -1).repeat(1, 1, 1,
+                                                   2).to(query_states.dtype))
+        sin = (torch.sin(pos_freq).view(batch, seqlen, 1,
+                                        -1).repeat(1, 1, 1,
+                                                   2).to(query_states.dtype))
+        if context:
+            setattr(context, 'cos', cos)
+            setattr(context, 'sin', sin)
+    cached_cos = context.cos if context else cos
+    cached_sin = context.sin if context else sin
+    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
+                                 cached_cos, cached_sin, None, None)
+    if out_q is None:
+        out_q = query_states
+    else:
+        out_q.copy_(query_states)
+    if out_k is None:
+        out_k = key_states
+    else:
+        out_k.copy_(key_states)
+    return out_q, out_k
diff --git a/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py
new file mode 100644
index 0000000000..87b5ad1b39
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
+    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(
+        router_logits, topk)
+    return routing_weights.to(torch.float32), selected_experts.to(torch.int64)
diff --git a/lmdeploy/pytorch/kernels/ascend/pagedattention.py b/lmdeploy/pytorch/kernels/ascend/pagedattention.py
new file mode 100644
index 0000000000..aa2609e476
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/pagedattention.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def prefill_attention(
+    query_states: Tensor,
+    key_states: Tensor,
+    value_states: Tensor,
+    attn_output: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    block_offsets: Tensor,
+    q_start_loc: Tensor,
+    q_seq_len: Tensor,
+    kv_seq_len: Tensor,
+    block_size: int,
+    kv_cache_len: int,
+    context=None,
+):
+    num_q_heads, dim = query_states.shape[1:3]
+    num_kv_heads = value_states.shape[1]
+
+    if context.is_unpaged_prefill:
+        ext_ops.prefill_attention(
+            query_states,
+            key_states,
+            value_states,
+            q_start_loc,
+            q_seq_len,
+            context.max_q_seq_length,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+
+
+def paged_decode_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
+                           max_kv_seq_len, block_offsets, block_size):
+    num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
+    ext_ops.paged_decode_attention(
+        q,
+        k_cache,
+        v_cache,
+        block_offsets,
+        block_size,
+        kv_seq_len,
+        max_kv_seq_len,
+        num_q_heads,
+        num_kv_heads,
+        attn_output=attn_output.view(q.shape),
+    )
+
+
+def paged_attention_fwd(
+    query_states: Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    attn_output: Tensor,
+    block_offsets: Tensor,
+    q_start_loc: Tensor,
+    q_seqlens: Tensor,
+    kv_seqlens: Tensor,
+    max_seqlen: int,
+    window_size: int = 1,
+    context=None,
+):
+    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
+    block_num, block_size, head, dim = key_cache.size()
+    kv_cache_len = block_num * block_size
+    k = key_cache.reshape(block_num * block_size, head, dim)
+    v = value_cache.reshape(block_num * block_size, head, dim)
+    if not is_decoding:
+        prefill_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_output,
+            k,
+            v,
+            block_offsets,
+            q_start_loc,
+            q_seqlens,
+            kv_seqlens,
+            block_size,
+            kv_cache_len,
+            context=context,
+        )
+    else:
+        paged_decode_attention(
+            query_states,
+            k,
+            v,
+            attn_output,
+            kv_seqlens,
+            context.max_kv_seq_length,
+            block_offsets,
+            block_size,
+        )
diff --git a/lmdeploy/pytorch/kernels/ascend/rms_norm.py b/lmdeploy/pytorch/kernels/ascend/rms_norm.py
new file mode 100644
index 0000000000..57b2f26c21
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/ascend/rms_norm.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from torch import Tensor
+
+
+def rms_norm(hidden_states: Tensor,
+             weight: Tensor,
+             eps: float = 1e-6,
+             out: Tensor = None):
+    rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps)
+    if out is None:
+        out = rms_norm_out
+    else:
+        out.copy_(rms_norm_out)
+    return out
diff --git a/lmdeploy/pytorch/kernels/dipu/__init__.py b/lmdeploy/pytorch/kernels/dipu/__init__.py
deleted file mode 100644
index 65ebc8cec1..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from ..default import multinomial_sampling
-from .apply_rotary_pos_emb import apply_rotary_pos_emb
-from .fill_kv_cache import fill_kv_cache
-from .fused_rotary_emb import fused_rotary_emb
-from .pagedattention import paged_attention_fwd
-from .rms_norm import rms_norm
-
-__all__ = [
-    'rms_norm',
-    'apply_rotary_pos_emb',
-    'fused_rotary_emb',
-    'fill_kv_cache',
-    'paged_attention_fwd',
-    'multinomial_sampling',
-]
diff --git a/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py
deleted file mode 100644
index 559cf8afba..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-from torch import Tensor
-
-
-def apply_rotary_pos_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    cos: Tensor,
-    sin: Tensor,
-    position_ids: Tensor,
-    position_ids_1d: Tensor,
-    q_embed=None,
-    k_embed=None,
-    context=None,
-):
-    bs, head, dim = query_states.shape
-    numKeyValueHeads = key_states.shape[1]
-    query_states = query_states.reshape(bs, head * dim)
-    key_states = key_states.reshape(bs, numKeyValueHeads * dim)
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        if cos.dim() == 3:
-            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
-            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
-        elif cos.dim() == 2:
-            cos = cos[position_ids_1d].view(1, bs, 1, -1)
-            sin = sin[position_ids_1d].view(1, bs, 1, -1)
-        else:
-            raise RuntimeError(f'Unsupport cos dim: {cos.dim()}')
-        setattr(context, 'cos', cos)
-        setattr(context, 'sin', sin)
-    ext.rotary_embedding_v2(query_states, key_states, context.cos, context.sin,
-                            dim)
-    return query_states.view(bs, head,
-                             dim), key_states.view(bs, numKeyValueHeads, dim)
diff --git a/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py
deleted file mode 100644
index f51b851185..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-from torch import Tensor
-
-
-def fill_kv_cache(
-    key_states: Tensor,
-    value_states: Tensor,
-    key_caches: Tensor,
-    value_caches: Tensor,
-    q_start_loc: Tensor,
-    q_seq_length: Tensor,
-    kv_seq_length: Tensor,
-    max_q_seq_length: int,
-    block_offsets: Tensor,
-    context: None,
-):
-    """fill key/value state to cache for paged attention."""
-    dest_index_copy_kv(key_states, context.kv_start_indices, key_caches)
-    dest_index_copy_kv(value_states, context.kv_start_indices, value_caches)
-
-
-def dest_index_copy_kv(states, dest_loc, caches):
-    block_num, block_size, head, dim = caches.size()
-    caches_tmp = caches.view(block_num * block_size, head, dim)
-    ext.dest_index_copy_kv(states, dest_loc, caches_tmp)
-    caches[:] = caches_tmp.view(block_num, block_size, head, dim)
diff --git a/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py
deleted file mode 100644
index 2a67a24516..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def fused_rotary_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    position_ids: torch.LongTensor,
-    inv_freq: Tensor,
-    scaling_factor: float,
-    out_q: Tensor = None,
-    out_k: Tensor = None,
-    context=None,
-):
-    _, bs, head, dim = query_states.shape
-    _, _, numKeyValueHeads, _ = key_states.shape
-    query_states = query_states.view(bs, head * dim)
-    key_states = key_states.view(bs, numKeyValueHeads * dim)
-    position_ids = position_ids.squeeze(0).unsqueeze(-1)
-    pos_freq = position_ids / scaling_factor * inv_freq
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = (torch.cos(pos_freq).view(position_ids.shape[0], 1,
-                                        -1).repeat(1, 1,
-                                                   2).to(query_states.dtype))
-        sin = (torch.sin(pos_freq).view(position_ids.shape[0], 1,
-                                        -1).repeat(1, 1,
-                                                   2).to(query_states.dtype))
-        setattr(context, 'cos', cos)
-        setattr(context, 'sin', sin)
-    ext.rotary_embedding_v2(query_states, key_states, context.cos, context.sin,
-                            dim)
-    query_states = query_states.view(1, bs, head, dim)
-    key_states = key_states.view(1, bs, numKeyValueHeads, dim)
-    return query_states, key_states
diff --git a/lmdeploy/pytorch/kernels/dipu/pagedattention.py b/lmdeploy/pytorch/kernels/dipu/pagedattention.py
deleted file mode 100644
index 9304ec0a35..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/pagedattention.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def flash_context_attention(
-    query_states: Tensor,
-    key_states: Tensor,
-    value_states: Tensor,
-    attn_output: Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seqlens: list,
-    kv_seqlens: list,
-    block_size: int,
-    kv_cache_len: int,
-    context=None,
-):
-    batch, head, dim = (
-        q_start_loc.shape[0],
-        query_states.shape[1],
-        query_states.shape[2],
-    )
-    numKeyValueHeads = value_states.shape[1]
-    assert key_states.shape[1] == value_states.shape[1]
-    for i in range(batch):
-        start = q_start_loc[i]
-        end = start + q_seqlens[i]
-        single_seqlen = int(end - start)
-        single_q = query_states[start:end].view(1, single_seqlen, -1)
-        single_k = key_states[start:end].reshape(1, single_seqlen, -1)
-        single_v = value_states[start:end].reshape(1, single_seqlen, -1)
-        single_out = attn_output[start:end, :].view(1, single_seqlen, -1)
-        mask = context.attention_mask[i]
-        if q_seqlens[i] == kv_seqlens[i]:
-            ext.prompt_flash_attention(
-                single_out,
-                single_q,
-                single_k,
-                single_v,
-                mask,
-                [kv_seqlens[i]],
-                kv_seqlens[i],
-                head,
-                numKeyValueHeads,
-                dim,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len,
-                                          numKeyValueHeads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              numKeyValueHeads * dim)
-            for j in range(q_seqlens[i]):
-                single_q = query_states[start + j:start + j + 1].view(1, 1, -1)
-                single_out = attn_output[start + j:start + j + 1].view(
-                    1, 1, -1)
-                ext.paged_attention(
-                    single_out,
-                    single_q,
-                    key_cache,
-                    value_cache,
-                    mask[j:j + 1],
-                    [kv_seqlens[i]],
-                    head,
-                    numKeyValueHeads,
-                    dim,
-                    block_offsets[i:i + 1],
-                    block_size,
-                )
-
-
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seqlens,
-                          block_table, block_size):
-    numKeyValueHeads = k_cache.shape[1]
-    assert k_cache.shape[1] == v_cache.shape[1]
-    bs, head, dim = q.shape
-    kv_cache_len = k_cache.shape[0]
-    q = q.reshape(bs, 1, head * dim)
-    k_cache = k_cache.reshape(1, kv_cache_len, numKeyValueHeads * dim)
-    v_cache = v_cache.reshape(1, kv_cache_len, numKeyValueHeads * dim)
-    ext.paged_attention(
-        attn_output.view(q.shape),
-        q,
-        k_cache,
-        v_cache,
-        None,
-        kv_seqlens,
-        head,
-        numKeyValueHeads,
-        dim,
-        block_table,
-        block_size,
-    )
-
-
-def paged_attention_fwd(
-    query_states: Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    attn_output: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seqlens: Tensor,
-    kv_seqlens: Tensor,
-    max_seqlen: int,
-    window_size: int = 1,
-    context=None,
-):
-    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
-    block_num, block_size, head, dim = key_cache.size()
-    kv_cache_len = block_num * block_size
-    k = key_cache.reshape(block_num * block_size, head, dim)
-    v = value_cache.reshape(block_num * block_size, head, dim)
-    if not is_decoding:
-        flash_context_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_output,
-            k,
-            v,
-            block_offsets.to(torch.int32),
-            q_start_loc,
-            q_seqlens.tolist(),
-            kv_seqlens.tolist(),
-            block_size,
-            kv_cache_len,
-            context=context,
-        )
-    else:
-        paged_token_attention(
-            query_states,
-            k,
-            v,
-            attn_output,
-            kv_seqlens.tolist(),
-            block_offsets.to(torch.int32),
-            block_size,
-        )
diff --git a/lmdeploy/pytorch/kernels/dipu/rms_norm.py b/lmdeploy/pytorch/kernels/dipu/rms_norm.py
deleted file mode 100644
index 8dbcf91ca2..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/rms_norm.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
-    output = torch.empty_like(hidden_states)
-    inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-    inv_rms = torch.empty(inv_rms_shape,
-                          dtype=torch.float32,
-                          device=hidden_states.device)
-    ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, None,
-                 eps)
-    return output
diff --git a/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py
new file mode 100644
index 0000000000..b8a55d4225
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dispatcher import FunctionDispatcher
+
+moe_gating_topk_softmax = FunctionDispatcher(
+    'moe_gating_topk_softmax').make_caller()
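
Illustrative note (not part of the patch): the new ascend wrapper above returns the routing result as float32 weights and int64 expert indices from dlinfer's ext_ops kernel. As a minimal CPU-only sketch of the assumed softmax-then-top-k semantics, the plain-PyTorch reference below can help when reasoning about the wrapper's outputs; the helper name moe_gating_topk_softmax_ref is hypothetical and this is not the NPU kernel itself.

import torch


def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    # Reference semantics only: softmax over the expert dimension,
    # then keep the top-k weights and their expert indices.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int64)


logits = torch.randn(4, 8)  # [num_tokens, num_experts]
weights, experts = moe_gating_topk_softmax_ref(logits, topk=2)
print(weights.dtype, experts.dtype)  # torch.float32 torch.int64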