support ascend mixtral
yao-fengchen authored and jinminxi104 committed Aug 20, 2024
1 parent 90ed8fb commit f170fb8
Showing 9 changed files with 247 additions and 19 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/__init__.py
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .fused_rotary_emb import fused_rotary_emb
from .moe_gating_topk_softmax import moe_gating_topk_softmax
from .paged_attention_fwd import paged_attention_fwd
from .rms_norm import rms_norm

@@ -12,5 +13,6 @@
'fused_rotary_emb',
'fill_kv_cache',
'paged_attention_fwd',
'moe_gating_topk_softmax',
'multinomial_sampling',
]
6 changes: 2 additions & 4 deletions lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -26,10 +26,8 @@ def apply_rotary_pos_emb(
setattr(context, 'sin', sin)
cached_cos = context.cos if context else cos
cached_sin = context.sin if context else sin
ext_ops.apply_rotary_pos_emb(
query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
None, None, None
)
ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
cached_cos, cached_sin, None, None, None)
if q_embed is None:
q_embed = query_states
else:
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py
@@ -16,5 +16,5 @@ def fill_kv_cache(
context: None,
):
"""fill key/value state to cache for paged attention."""
ext_ops.fill_kv_cache(key_states, value_states, key_caches,
value_caches, context.kv_start_indices)
ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches,
context.kv_start_indices)
10 changes: 6 additions & 4 deletions lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -21,10 +21,12 @@ def fused_rotary_emb(
position_ids = position_ids.squeeze(0).unsqueeze(-1)
pos_freq = position_ids / scaling_factor * inv_freq
if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1)
.repeat(1, 1, 1, 2).to(query_states.dtype))
sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1)
.repeat(1, 1, 1, 2).to(query_states.dtype))
cos = (torch.cos(pos_freq).view(batch, seqlen, 1,
-1).repeat(1, 1, 1,
2).to(query_states.dtype))
sin = (torch.sin(pos_freq).view(batch, seqlen, 1,
-1).repeat(1, 1, 1,
2).to(query_states.dtype))
if context:
setattr(context, 'cos', cos)
setattr(context, 'sin', sin)
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import infer_ext.ops as ext_ops
import torch
from torch import Tensor


def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(
router_logits, topk)
return routing_weights.to(torch.float32), selected_experts.to(torch.int64)
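For reference, the Ascend op above is assumed to match the stock Mixtral routing: a softmax over the router logits followed by a top-k selection. A minimal plain-PyTorch sketch of that assumed equivalent (illustrative only, not part of this commit):

import torch

def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    """Reference routing: softmax over experts, then keep the top-k."""
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(probs, topk, dim=-1)
    return routing_weights.to(torch.float32), selected_experts.to(torch.int64)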
21 changes: 12 additions & 9 deletions lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -21,7 +21,7 @@ def flash_context_attention(
):
num_q_heads, dim = query_states.shape[1:3]
num_kv_heads = value_states.shape[1]
batch = q_start_loc.shape[0]
batch = q_start_loc.shape[0]

for i in range(batch):
if torch.equal(q_seq_len[i], kv_seq_len[i]):
@@ -30,30 +30,32 @@
query_states,
key_states,
value_states,
q_start_loc[i:i+1],
q_seq_len[i:i+1],
q_start_loc[i:i + 1],
q_seq_len[i:i + 1],
num_q_heads,
num_kv_heads,
context.attention_mask[i:i+1],
context.attention_mask[i:i + 1],
)
else:
key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
value_cache = value_cache.reshape(1, kv_cache_len,
num_kv_heads * dim)
ext_ops.paged_prefill_attention(
attn_output,
query_states,
key_cache,
value_cache,
block_offsets,
block_size,
q_start_loc[i:i+1],
q_seq_len[i:i+1],
kv_seq_len[i:i+1],
q_start_loc[i:i + 1],
q_seq_len[i:i + 1],
kv_seq_len[i:i + 1],
num_q_heads,
num_kv_heads,
context.attention_mask[i:i+1],
context.attention_mask[i:i + 1],
)


def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
block_offsets, block_size):
num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
@@ -69,6 +71,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
num_kv_heads,
)


def paged_attention_fwd(
query_states: Tensor,
key_states: torch.Tensor,
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dispatcher import FunctionDispatcher

moe_gating_topk_softmax = FunctionDispatcher(
'moe_gating_topk_softmax').make_caller()
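This registers a device-agnostic entry point: callers invoke moe_gating_topk_softmax like a regular function and the FunctionDispatcher resolves the backend kernel (the Ascend one above, on Ascend devices) at call time. A hedged usage sketch, assuming an Ascend device and dummy logits:

import torch
from lmdeploy.pytorch.kernels.moe_gating_topk_softmax import moe_gating_topk_softmax

router_logits = torch.randn(4, 8)  # (num_tokens, num_experts), illustrative shapes
routing_weights, selected_experts = moe_gating_topk_softmax(router_logits, 2)
# On the Ascend backend this yields float32 weights and int64 expert indices.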
196 changes: 196 additions & 0 deletions lmdeploy/pytorch/models/mixtral.py
@@ -8,6 +8,7 @@

from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
from ..kernels.fused_moe import fused_moe
from ..kernels.moe_gating_topk_softmax import moe_gating_topk_softmax
from ..weight_loader.dist_utils import (colwise_parallelize_linear,
rowwise_parallelize_linear)

@@ -156,6 +157,157 @@ def forward(
)


class PatchedMixtralAttentionAscend(nn.Module):
"""Rewrite module of MixtralAttention."""

def _load_weights(self, loader, rank: int, world_size: int,
device: torch.device):
"""load weights."""
for mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear(getattr(self, mod_name),
loader,
rank=rank,
world_size=world_size,
prefix=mod_name)
rowwise_parallelize_linear(self.o_proj,
loader,
rank=rank,
world_size=world_size,
prefix='o_proj')

@classmethod
def _distribute_output_fn(cls, outputs, **kwargs):
"""Distribution output hook."""
try:
dist.all_reduce(outputs[0])
except Exception as e:
print(e)
return outputs

def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""default rewrite."""

context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length

num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
hidden_size = num_heads * self.head_dim

def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

return query_states, key_states, value_states

def __rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
if not hasattr(context, '_cos'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
context._cos = cos
context._sin = sin
else:
cos = context._cos
sin = context._sin
query_states, key_states = apply_rotary_pos_emb(
query_states,
key_states,
cos,
sin,
position_ids,
context.position_ids_1d,
q_embed=query_states,
k_embed=key_states,
context=context)
return query_states, key_states, value_states

query_states, key_states, value_states = __qkv_proj(hidden_states)

query_states = query_states.view(-1, num_heads, self.head_dim)
key_states = key_states.view(-1, num_kv_heads, self.head_dim)
value_states = value_states.view(-1, num_kv_heads, self.head_dim)

query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
# fill kv cache
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
context=context,
)
# page attention
attn_output = query_states
window_size = self.config.sliding_window or -1
paged_attention_fwd(
query_states,
key_states,
value_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
window_size=window_size,
context=context,
)

attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)

attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of MistralAttention.forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
output_attentions,
attention_mask=attention_mask,
world_size=world_size,
)
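The attention rewrite shards heads across tensor-parallel ranks before the per-token reshape. A small sketch of the per-rank sizing, assuming illustrative Mixtral-like values (32 query heads, 8 KV heads, head_dim 128, two ranks):

num_heads, num_key_value_heads, head_dim, world_size = 32, 8, 128, 2
per_rank_q_heads = num_heads // world_size               # 16
per_rank_kv_heads = num_key_value_heads // world_size    # 4
per_rank_hidden_size = per_rank_q_heads * head_dim       # 2048, used for the output reshape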


class PatchedMixtralBLockSparseTop2MLP(nn.Module):

def _load_weights(self, loader, rank: int, world_size: int,
@@ -255,6 +407,50 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return out_states, router_logits


class PatchedMixtralSparseMoeBlockAscend(nn.Module):

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""rewrite moe forward."""

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states)

routing_weights, selected_experts = moe_gating_topk_softmax(
router_logits, self.top_k)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device)

expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

if top_x.shape[0] == 0:
continue

top_x_list = top_x.tolist()
idx_list = idx.tolist()

current_state = hidden_states[None,
top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(
current_state) * routing_weights[top_x_list, idx_list, None]

final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
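The one_hot/permute/where bookkeeping above is compact but easy to misread; a tiny standalone sketch with toy shapes (illustrative only) shows what each tensor holds:

import torch

num_experts, top_k, num_tokens = 4, 2, 3
selected_experts = torch.tensor([[0, 2], [1, 2], [0, 3]])  # (num_tokens, top_k)

# (num_experts, top_k, num_tokens): expert_mask[e, k, t] == 1 iff token t picked
# expert e as its k-th routing choice.
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    # top_x: token indices routed to this expert; idx: which top-k slot they used.
    print(expert_idx, top_x.tolist(), idx.tolist())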


class PatchedMixtralModel(nn.Module):

def _continuous_batching_forward(
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/models/module_map.py
@@ -390,3 +390,15 @@
'modeling_internlm2.InternLM2FlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend',
})

# ascend mixtral
ASCEND_MODULE_MAP.update({
'transformers.models.mixtral.modeling_mixtral.MixtralAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend',
'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend',
'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend',
'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlockAscend', # noqa: E501
})
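These entries point every Mixtral attention variant and the sparse MoE block at the Ascend-patched classes. Conceptually, each value is a dotted path to the replacement class; a rough, illustrative sketch of resolving one entry (not lmdeploy's actual patching code):

import importlib

def resolve(dotted_path: str):
    """Import 'pkg.module.ClassName' and return the class object."""
    module_name, _, cls_name = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_name), cls_name)

# e.g. resolve(ASCEND_MODULE_MAP[
#     'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock'])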
