Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions vllm_ascend/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,10 @@
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

from vllm.model_executor.models.qwen3_next import Qwen3NextAttention # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextModel # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock # isort: skip
from vllm.model_executor.models.qwen3_next import fused_gdn_gating # isort: skip
from vllm.model_executor.models.qwen3_next import ( # isort: skip
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
fused_gdn_gating)


class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
Expand Down Expand Up @@ -429,17 +426,16 @@ class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):

def __init__(
self,
config: Qwen3NextConfig,
vllm_config: VllmConfig,
layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
nn.Module.__init__(self)
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config

self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
Expand Down Expand Up @@ -468,12 +464,8 @@ def __init__(
if (self.layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3NextMLP(
hidden_size=config.hidden_size,
Expand All @@ -493,14 +485,14 @@ def __init__(
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )

Expand All @@ -511,13 +503,8 @@ class CustomQwen3NextModel(Qwen3NextModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts

Expand All @@ -534,14 +521,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

def get_layer(prefix: str):
return CustomQwen3NextDecoderLayer(
config,
vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix,
enable_eplb=enable_eplb,
)

self.start_layer, self.end_layer, self.layers = make_layers(
Expand Down
Loading