Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/user_guide/feature_guide/graph_mode.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa
There are two kinds of graph mode supported by vLLM Ascend:

- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
- **XliteGraph**: This is the openeuler xlite graph mode. In v0.11.0, only Llama, Qwen dense series models, and Qwen3-vl are supported.
- **XliteGraph**: This is the openeuler xlite graph mode. In v0.11.0, only Llama, Qwen dense series models, Qwen MoE series models, and Qwen3-vl are supported.

## Using ACLGraph

Expand All @@ -38,7 +38,7 @@ vllm serve Qwen/Qwen2-7B-Instruct

## Using XliteGraph

If you want to run Llama, Qwen dense series models, or Qwen3-vl with xlite graph mode, please install xlite, and set xlite_graph_config.
If you want to run Llama, Qwen dense series models, Qwen MoE series models, or Qwen3-vl with xlite graph mode, please install xlite, and set xlite_graph_config.

```bash
pip install xlite
Expand All @@ -61,7 +61,7 @@ Online example:
vllm serve path/to/Qwen3-32B --tensor-parallel-size 8 --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
```

You can find more details abort xlite [here](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md)
You can find more details about xlite [here](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md)

## Fallback to the Eager Mode

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ pytest_mock
msserviceprofiler>=1.2.2
mindstudio-probe>=8.3.0
arctic-inference==0.1.1
xlite==0.1.0rc0
xlite==0.1.0rc1
uc-manager
6 changes: 3 additions & 3 deletions tests/e2e/singlecard/test_xlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
n=1,
))

CASE_FULL_DECODE_ONLY = LLMTestCase(
CASE_FULL = LLMTestCase(
model="Qwen/Qwen3-0.6B",
prompts=[
"Hello, my name is", "The president of the United States is",
Expand All @@ -57,7 +57,7 @@
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the",
' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president',
' Paris. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of China is Beijing. The capital of Japan is Tokyo. The capital',
" not just about the technology itself, but about how we use it to solve real-world problems. As AI continues to evolve, it's important to consider the ethical"
" not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and"
],
sampling_params=SamplingParams(
max_tokens=32,
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_models_with_xlite_decode_only(cur_case: LLMTestCase):
golden_answers=cur_case.golden_answers)


@pytest.mark.parametrize("cur_case", [CASE_FULL_DECODE_ONLY])
@pytest.mark.parametrize("cur_case", [CASE_FULL])
def test_models_with_xlite_full_mode(cur_case: LLMTestCase):
runner_kwargs = {
"model_name": cur_case.model,
Expand Down
6 changes: 6 additions & 0 deletions vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

from vllm.config.compilation import CUDAGraphMode

if ascend_config.xlite_graph_config.enabled and ascend_config.xlite_graph_config.full_mode:
logger.info("ACLGraph is disabled under xlite full mode")
enforce_eager = True
model_config.enforce_eager = True
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

if enforce_eager:
logger.info("Compilation disabled, using eager mode by default")
compilation_config.mode = CompilationMode.NONE
Expand Down
169 changes: 120 additions & 49 deletions vllm_ascend/xlite/xlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
from vllm.distributed import (get_ep_group,
get_tensor_model_parallel_world_size,
get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime
from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime,
ScoringFuncSoftmax)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
Expand All @@ -47,6 +49,55 @@ def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
dtype = vllm_config.model_config.dtype
config = self._build_model_config(vllm_config)
xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank()
xlite_model.init(config, rank)

freq_cis = self._precompute_freqs_cis(config.head_dim,
config.max_seq_len, dtype,
config.rope_theta)

return (xlite_model, freq_cis, config.hidden_size, dtype)

def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
hf_config = vllm_config.model_config.hf_text_config
if hasattr(hf_config, "text_config"):
hf_config = hf_config.text_config
config = ModelConfig()
config.vocab_size = hf_config.vocab_size
config.hidden_size = hf_config.hidden_size
config.n_layers = hf_config.num_hidden_layers
config.n_heads = hf_config.num_attention_heads
config.n_kv_heads = hf_config.num_key_value_heads
if hasattr(hf_config, "head_dim"):
config.head_dim = hf_config.head_dim
else:
config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
config.rope_head_dim = config.head_dim
config.norm_eps = hf_config.rms_norm_eps
config.rope_theta = hf_config.rope_theta
config.softmax_scale = config.head_dim**-0.5
config.n_dense_layers = hf_config.num_hidden_layers
config.intermediate_size = hf_config.intermediate_size
config.def_tp_size = get_tensor_model_parallel_world_size()
config.def_dp_size = 1
config.moe_ep_size = 1
config.moe_tp_size = 1

config.attn_type = AttnMHA
config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
scheduler_config = vllm_config.scheduler_config
max_batch_size = scheduler_config.max_num_seqs
max_seq_len = vllm_config.model_config.max_model_len
config.max_m = scheduler_config.max_num_batched_tokens
config.max_batch_size = max_batch_size
config.max_seq_len = max_seq_len
config.block_size = vllm_config.cache_config.block_size
return config

def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
config: ModelConfig) -> Model:
params_dict = dict(runnable.named_parameters())

if hasattr(runnable, "language_model"):
Expand All @@ -56,7 +107,6 @@ def initialize(
layers = runnable.model.layers
model_prefix = ""

config = self._build_model_config(vllm_config)
xlite_model = Model()
xlite_model.embed = params_dict.get(model_prefix +
"model.embed_tokens.weight")
Expand All @@ -79,8 +129,14 @@ def initialize(
]
xlite_model.mlp_up_gate = [
layer.mlp.gate_up_proj.weight for layer in layers
if hasattr(layer.mlp, "gate_up_proj")
and layer.mlp.gate_up_proj.weight is not None
]
xlite_model.mlp_down = [
layer.mlp.down_proj.weight for layer in layers
if hasattr(layer.mlp, "down_proj")
and layer.mlp.down_proj.weight is not None
]
xlite_model.mlp_down = [layer.mlp.down_proj.weight for layer in layers]
mha_qkv_bias = [
layer.self_attn.qkv_proj.bias for layer in layers
if hasattr(layer.self_attn.qkv_proj, "bias")
Expand Down Expand Up @@ -108,50 +164,7 @@ def initialize(
xlite_model.mha_q_norm = q_norm
xlite_model.mha_k_norm = k_norm

rank = torch.distributed.get_rank()
xlite_model.init(config, rank)

freq_cis = self._precompute_freqs_cis(config.head_dim,
config.max_seq_len, dtype,
config.rope_theta)

return (xlite_model, freq_cis, config.hidden_size, dtype)

def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
hf_config = vllm_config.model_config.hf_text_config
if hasattr(hf_config, "text_config"):
hf_config = hf_config.text_config
config = ModelConfig()
config.vocab_size = hf_config.vocab_size
config.hidden_size = hf_config.hidden_size
config.n_layers = hf_config.num_hidden_layers
config.n_heads = hf_config.num_attention_heads
config.n_kv_heads = hf_config.num_key_value_heads
if hasattr(hf_config, "head_dim"):
config.head_dim = hf_config.head_dim
else:
config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
config.rope_head_dim = config.head_dim
config.norm_eps = hf_config.rms_norm_eps
config.rope_theta = hf_config.rope_theta
config.softmax_scale = config.head_dim**-0.5
config.n_dense_layers = hf_config.num_hidden_layers
config.intermediate_size = hf_config.intermediate_size
config.def_tp_size = get_tensor_model_parallel_world_size()
config.def_dp_size = 1
config.moe_ep_size = 1
config.moe_tp_size = 1

config.attn_type = AttnMHA
config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
scheduler_config = vllm_config.scheduler_config
max_batch_size = scheduler_config.max_num_seqs
max_seq_len = vllm_config.model_config.max_model_len
config.max_m = scheduler_config.max_num_batched_tokens
config.max_batch_size = max_batch_size
config.max_seq_len = max_seq_len
config.block_size = vllm_config.cache_config.block_size
return config
return xlite_model

def _precompute_freqs_cis(self,
dim: int,
Expand All @@ -168,6 +181,62 @@ def _precompute_freqs_cis(self,
return freq_cis.to(device='npu')


class QwenMoeXliteModel(LlamaXliteModel):

def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
architecture = vllm_config.model_config.architectures[0]
raise ValueError(
f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
dtype = vllm_config.model_config.dtype
config = self._build_model_config(vllm_config)
xlite_model = self._build_model(runnable, vllm_config, config)
rank = torch.distributed.get_rank()
xlite_model.init(config, rank)

freq_cis = super()._precompute_freqs_cis(config.head_dim,
config.max_seq_len, dtype,
config.rope_theta)

return (xlite_model, freq_cis, config.hidden_size, dtype)

def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
config = super()._build_model_config(vllm_config)
hf_config = vllm_config.model_config.hf_text_config
ep_group = get_ep_group()
config.n_layers = hf_config.max_window_layers
config.n_dense_layers = 0
config.n_routed_experts = hf_config.num_experts
config.n_shared_experts = 0
config.n_act_experts = hf_config.num_experts_per_tok
config.def_dp_size = vllm_config.parallel_config.data_parallel_size
config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
config.experts_weight_transpose = True
config.moe_intermediate_size = hf_config.moe_intermediate_size
config.norm_topk_prob = hf_config.norm_topk_prob
config.scoring_func = ScoringFuncSoftmax
return config

def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
config: ModelConfig) -> Model:
xlite_model = super()._build_model(runnable, vllm_config, config)
layers = runnable.model.layers
xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
xlite_model.re_up_gate = [
layer.mlp.experts.w13_weight[i] for layer in layers
for i in range(layer.mlp.experts.local_num_experts)
]
xlite_model.re_down = [
layer.mlp.experts.w2_weight[i] for layer in layers
for i in range(layer.mlp.experts.local_num_experts)
]

return xlite_model


def xlite_model_init(
runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
Expand All @@ -176,6 +245,7 @@ def xlite_model_init(
"Qwen2ForCausalLM": LlamaXliteModel,
"Qwen3ForCausalLM": LlamaXliteModel,
"Qwen3VLForConditionalGeneration": LlamaXliteModel,
"Qwen3MoeForCausalLM": QwenMoeXliteModel,
}

architecture = vllm_config.model_config.architectures[0]
Expand All @@ -197,7 +267,8 @@ def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
rank = torch.distributed.get_rank()
local_rank = get_world_group().local_rank
self.xlite_rt = Runtime(local_rank, 0, rank,
get_tensor_model_parallel_world_size())
get_tensor_model_parallel_world_size(),
vllm_config.parallel_config.data_parallel_size)

(self.xlite_model, self.freq_cis, hidden_size,
dtype) = xlite_model_init(runnable, vllm_config)
Expand Down