Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class AscendMLAMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.

with_prefill_across_dp: bool = False

# The dimension of the attention heads
head_dim: Optional[int] = None
attn_mask: torch.Tensor = None
Expand Down Expand Up @@ -260,6 +262,10 @@ def build_dummy(self, num_reqs: int,
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
query_start_loc = torch.full((num_reqs, ),
-1,
dtype=torch.int32,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
Expand All @@ -278,15 +284,21 @@ def build_dummy(self, num_reqs: int,
attn_state=AscendAttentionState.DecodeOnly,
prefill=None,
decode=decode_metadata,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
block_tables=block_table,
)

def build(self,
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1) -> AscendMLAMetadata:
def build(
self,
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1,
with_prefill_across_dp: bool = False,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs

# Note(simon): be careful about the CPU <> GPU memory movement in this
Expand Down Expand Up @@ -388,6 +400,7 @@ def build(self,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
with_prefill_across_dp=with_prefill_across_dp,
)


Expand Down Expand Up @@ -621,7 +634,7 @@ def exec_kv(
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
Expand All @@ -643,7 +656,7 @@ def rope_single(
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)

def _forward_decode(
Expand Down Expand Up @@ -766,6 +779,7 @@ def forward(
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]

decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
Expand Down
77 changes: 49 additions & 28 deletions vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def __init__(
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group

self.params_dtype = torch.get_default_dtype()

self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -228,52 +236,65 @@ def forward(
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
num_tokens, hidden_dim = hidden_states.shape
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

num_tokens, hidden_size = hidden_states.shape

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

if self.tp_size > 1:
# pass
num_tokens, hidden_size = hidden_states.shape
if num_tokens < self.tp_size:
target_size = self.tp_size
new_hidden_states = torch.empty([target_size, hidden_size],
dtype=hidden_states.dtype,
device=hidden_states.device)
new_hidden_states[:num_tokens] = hidden_states
hidden_states = new_hidden_states
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
local_hidden_states = chunk_hidden_states[self.tp_rank]
else:
local_hidden_states = hidden_states
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
hidden_states = chunks[self.tp_rank]
elif not self.enable_graph_mode:
num_padding_tokens = (self.tp_size -
num_tokens % self.tp_size) % self.tp_size
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
if num_padding_tokens > 0:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, num_padding_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
hidden_states = chunk_hidden_states[self.tp_rank]

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(local_hidden_states)
router_logits, _ = self.gate(hidden_states)

router_hidden_states = self.experts(
hidden_states=local_hidden_states,
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor

if self.tp_size > 1:
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < self.tp_size:
final_hidden_states = final_hidden_states[:num_tokens]
else:
final_hidden_states = router_hidden_states
if self.enable_graph_mode:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
final_hidden_states = torch.zeros(
[num_tokens, hidden_size],
dtype=self.params_dtype,
device="npu")
dist.all_gather_into_tensor(final_hidden_states,
hidden_states, self.tp_group)
hidden_states = final_hidden_states
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
else:
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_padding_tokens > 0:
hidden_states = hidden_states[:-num_padding_tokens]

if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
hidden_states = hidden_states + shared_output

return final_hidden_states.view(num_tokens, hidden_dim)
return hidden_states.view(num_tokens, hidden_size)


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
Expand Down
91 changes: 65 additions & 26 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,12 @@ def __init__(self, moe: MoEConfig = None):
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size

self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

try:
device_group = ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
Expand Down Expand Up @@ -664,7 +670,7 @@ def apply(
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
elif get_ep_group().world_size == 1:
elif self.enable_graph_mode or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
Expand Down Expand Up @@ -750,26 +756,20 @@ def __init__(
self.expert_map = None
self.activation = activation

if self.ep_size > 1:
# Create a tensor of size num_experts filled with -1
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)

self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
# Create a tensor of size num_experts filled with -1
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)

else:
# Adjust TP size for DP attention
# haven't test its functionality yet, may remove in the future
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
self.moe_parallel_config.ep_rank = 0
self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
self.moe_parallel_config.ep_size = 1
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
Expand Down Expand Up @@ -807,8 +807,15 @@ def __init__(
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size

self.ep_group = get_ep_group()
self.quant_method.create_weights(layer=self, **moe_quant_params)

self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -822,11 +829,28 @@ def forward(self,
else:
real_top_k = self.top_k

if VLLM_ENABLE_MC2 and not is_prefill:
...
# MC2 ag/rs broadcast/all_reduce
# prefill_req x x √
# decode_req √ x √
# graph_mode √ √ x
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.enable_graph_mode:
if USING_LCCL_COM: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
elif self.enable_graph_mode and not is_prefill:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
else:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
Expand All @@ -843,11 +867,26 @@ def forward(self,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)

if VLLM_ENABLE_MC2 and not is_prefill:
...
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.enable_graph_mode:
if USING_LCCL_COM: # type: ignore
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
elif self.enable_graph_mode and not is_prefill:
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
else:
hidden_states = get_ep_group().combine(hidden_states)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

return final_hidden_states
return hidden_states
4 changes: 3 additions & 1 deletion vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# Calculate expert parallel size based on world size
parallel_config.expert_parallel_size = (
parallel_config.world_size //
parallel_config.world_size_across_dp //
parallel_config.expert_tensor_parallel_size)

if model_config is None:
Expand Down Expand Up @@ -167,6 +167,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
raise NotImplementedError(
"enable_graph_mode only works with deepseek model."
)
# Set compilation level to NO_COMPILATION to disable ACL Graph
compilation_config.level = CompilationLevel.NO_COMPILATION

elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
model_type = model_config.hf_config.model_type
Expand Down
9 changes: 8 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
import torch.distributed as dist
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import GroupCoordinator

import vllm_ascend.envs as envs_ascend
Expand Down Expand Up @@ -508,6 +509,12 @@ def __init__(self):

self.ep_group = get_ep_group()

self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)

try:
device_group = self.ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
Expand Down Expand Up @@ -629,7 +636,7 @@ def apply(
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
elif self.ep_group.world_size == 1:
elif self.enable_graph_mode or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
Expand Down
Loading