Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
51f9418
add mxfp8 moe quantization
Feb 10, 2026
9912d19
Code naming optimization
wangyao-i Feb 11, 2026
facc18b
cleancode
Eric-dot Feb 11, 2026
9c2d244
cleancode
Eric-dot Feb 11, 2026
d67aa97
fix function missing, kwargs mismatching and missed attribute
linfeng-yuan Feb 11, 2026
4df796b
Add explicit MXFP8 quant type and keep prepare-stage quantization dis…
linfeng-yuan Feb 11, 2026
a0af93e
Unify dispatch quantization via with_quant and remove fp8_comm toggle
linfeng-yuan Feb 11, 2026
38729b7
remove unused variables
linfeng-yuan Feb 12, 2026
6817418
Merge branch 'main' into main
Eric-dot Feb 24, 2026
9ddfd96
Add mxfp dtype compatibility layer with explicit runtime capability c…
linfeng-yuan Feb 24, 2026
2e69836
cleancode and leave note for removing the compatiblity shim
linfeng-yuan Feb 24, 2026
f41f501
change importing file
linfeng-yuan Feb 24, 2026
fbcb75d
remove unstable ccu setting and unnecessary global_bs setting
linfeng-yuan Feb 24, 2026
35adc91
Refactor MoE token dispatch call path by unifying shared kwargs const…
linfeng-yuan Feb 24, 2026
59774ae
clean up A5 MXFP8 integration and MC2 quantization control
linfeng-yuan Feb 26, 2026
314aee2
mxfp8 moe_mlp quant_apply_mlp adaptation
Eric-dot Feb 26, 2026
5fd1a15
format CI
linfeng-yuan Feb 26, 2026
1fa339b
mxfp npu_moe_init_routing adaptation
Eric-dot Feb 27, 2026
bfb9cd0
Refine MoE MXFP integration and gmm2 adaptor flow
linfeng-yuan Feb 27, 2026
a2a7698
format CI
linfeng-yuan Feb 27, 2026
923ccbe
refine gmm2 in DeviceAdaptor and modify commnets
linfeng-yuan Feb 27, 2026
639c1c3
remove redudant dynamic quant calling
linfeng-yuan Feb 27, 2026
98bf2e2
use DeviceOperator directly
linfeng-yuan Feb 27, 2026
ddf4396
fix: narrow DeviceOperator type for mypy
linfeng-yuan Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,11 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
elif soc_version in {AscendDeviceType._310P}:
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendDeviceType.A5}:
if num_tokens <= mc2_tokens_capacity and vllm_config.parallel_config.world_size_across_dp > 1:
moe_comm_type = MoECommType.MC2
else:
moe_comm_type = MoECommType.ALLTOALL
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type
327 changes: 324 additions & 3 deletions vllm_ascend/device/device_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
import torch_npu

from vllm_ascend.quantization.mxfp_compat import (
FLOAT4_E2M1FN_X2_DTYPE,
FLOAT8_E8M0FNU_DTYPE,
HIFLOAT8_DTYPE,
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


Expand All @@ -28,6 +33,126 @@ def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping
)

@staticmethod
def npu_moe_init_routing(
hidden_states,
topk_ids,
*,
scale=None,
active_num: int,
expert_num: int,
expert_tokens_num_type: int = 1,
expert_tokens_num_flag: bool = True,
active_expert_range=None,
quant_mode: int = -1,
):
return torch.ops._C_ascend.npu_moe_init_routing_custom(
hidden_states,
topk_ids,
scale=scale,
active_num=active_num,
expert_num=expert_num,
expert_tokens_num_type=expert_tokens_num_type,
expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range,
quant_mode=quant_mode,
)

@staticmethod
def npu_dynamic_quant(
hidden_states: torch.Tensor,
dynamic_scale: torch.Tensor | None = None,
*,
act_quant_type=torch.float8_e4m3fn,
use_mxfp_quant: bool = False,
):
if use_mxfp_quant:
raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.")

if dynamic_scale is None:
return torch_npu.npu_dynamic_quant(hidden_states)

return hidden_states, dynamic_scale

@staticmethod
def npu_grouped_matmul_swiglu_quant(
*,
x: torch.Tensor,
weight: torch.Tensor,
group_list: torch.Tensor,
weight_scale: torch.Tensor,
x_scale: torch.Tensor,
bias=None,
use_mxfp_quant: bool = False,
):
if use_mxfp_quant:
raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.")

return torch_npu.npu_grouped_matmul_swiglu_quant(
x=x,
weight=weight,
bias=bias,
group_list=group_list,
weight_scale=weight_scale,
x_scale=x_scale,
)

@staticmethod
def get_quant_gmm2_kwargs(
*,
input_dtype: torch.dtype,
act_quant_type,
weight_quant_type,
scale_type,
per_token_scale_type,
use_bf16: bool = True,
use_mxfp_quant: bool = False,
) -> dict:
if use_mxfp_quant:
raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.")

return {
"output_dtype": input_dtype if input_dtype in [torch.bfloat16, torch.float16] else torch.bfloat16,
}

@classmethod
def npu_grouped_matmul_gmm2(
cls,
*,
hidden_states: torch.Tensor,
weight: list[torch.Tensor] | torch.Tensor,
weight_scale: list[torch.Tensor] | torch.Tensor,
per_token_scale: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int,
input_dtype: torch.dtype,
act_quant_type,
weight_quant_type,
scale_type,
per_token_scale_type,
use_bf16: bool = True,
use_mxfp_quant: bool = False,
bias=None,
fallback_output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
if use_mxfp_quant:
raise RuntimeError("MXFP8 MoE quantization is only supported on Ascend A5.")

if fallback_output_dtype is None:
fallback_output_dtype = weight_scale[0].dtype if isinstance(weight_scale, list) else weight_scale.dtype
return torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=weight,
scale=weight_scale,
bias=bias,
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=fallback_output_dtype,
)[0]


class A5DeviceAdaptor(BaseDeviceAdaptor):
@classmethod
Expand All @@ -36,12 +161,208 @@ def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping
)

@staticmethod
def npu_moe_init_routing(
hidden_states,
topk_ids,
*,
scale=None,
active_num: int,
expert_num: int,
expert_tokens_num_type: int = 1,
expert_tokens_num_flag: bool = True,
active_expert_range=None,
quant_mode: int = -1,
):
return torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=scale,
active_num=active_num,
expert_num=expert_num,
expert_tokens_num_type=expert_tokens_num_type,
expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range,
quant_mode=quant_mode,
)

@staticmethod
def npu_dynamic_quant(
hidden_states: torch.Tensor,
dynamic_scale: torch.Tensor | None = None,
*,
act_quant_type=torch.float8_e4m3fn,
use_mxfp_quant: bool = False,
):
if not use_mxfp_quant:
return BaseDeviceAdaptor.npu_dynamic_quant(
hidden_states,
dynamic_scale,
act_quant_type=act_quant_type,
use_mxfp_quant=False,
)

if dynamic_scale is None:
return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)

if dynamic_scale.ndim == 2:
dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2)

return hidden_states, dynamic_scale

@staticmethod
def npu_grouped_matmul_swiglu_quant(
*,
x: torch.Tensor,
weight: torch.Tensor,
group_list: torch.Tensor,
weight_scale: torch.Tensor,
x_scale: torch.Tensor,
bias=None,
use_mxfp_quant: bool = False,
):
if not use_mxfp_quant:
return BaseDeviceAdaptor.npu_grouped_matmul_swiglu_quant(
x=x,
weight=weight,
group_list=group_list,
weight_scale=weight_scale,
x_scale=x_scale,
bias=bias,
use_mxfp_quant=False,
)

out, out_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2(
x=x,
weight=[weight],
group_list=group_list,
weight_scale=[weight_scale],
x_scale=x_scale,
dequant_mode=2,
quant_mode=2,
dequant_dtype=torch.float32,
quant_dtype=torch.float8_e4m3fn,
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
)
return out, out_scale, None

@staticmethod
def get_quant_gmm2_kwargs(
*,
input_dtype: torch.dtype,
act_quant_type,
weight_quant_type,
scale_type,
per_token_scale_type,
use_bf16: bool = True,
use_mxfp_quant: bool = False,
) -> dict:
if not use_mxfp_quant:
return BaseDeviceAdaptor.get_quant_gmm2_kwargs(
input_dtype=input_dtype,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
use_bf16=use_bf16,
use_mxfp_quant=False,
)

quant_dtypes = tuple(dtype for dtype in (FLOAT4_E2M1FN_X2_DTYPE, HIFLOAT8_DTYPE) if dtype is not None)
scale_dtypes = tuple(dtype for dtype in (FLOAT8_E8M0FNU_DTYPE,) if dtype is not None)

output_dtype = (
input_dtype
if input_dtype in [torch.bfloat16, torch.float16]
else (torch.bfloat16 if use_bf16 else torch.float16)
)

return {
"scale_dtype": scale_type if scale_type in scale_dtypes else None,
"per_token_scale_dtype": per_token_scale_type if per_token_scale_type in scale_dtypes else None,
"x_dtype": act_quant_type if act_quant_type in quant_dtypes else None,
"weight_dtype": weight_quant_type if weight_quant_type in quant_dtypes else None,
"output_dtype": output_dtype,
}

@classmethod
def npu_grouped_matmul_gmm2(
cls,
*,
hidden_states: torch.Tensor,
weight: list[torch.Tensor] | torch.Tensor,
weight_scale: list[torch.Tensor] | torch.Tensor,
per_token_scale: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int,
input_dtype: torch.dtype,
act_quant_type,
weight_quant_type,
scale_type,
per_token_scale_type,
use_bf16: bool = True,
use_mxfp_quant: bool = False,
bias=None,
fallback_output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
if not use_mxfp_quant:
return BaseDeviceAdaptor.npu_grouped_matmul_gmm2(
hidden_states=hidden_states,
weight=weight,
weight_scale=weight_scale,
per_token_scale=per_token_scale,
group_list=group_list,
group_list_type=group_list_type,
input_dtype=input_dtype,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
use_bf16=use_bf16,
use_mxfp_quant=False,
bias=bias,
fallback_output_dtype=fallback_output_dtype,
)

gmm2_kwargs = cls.get_quant_gmm2_kwargs(
input_dtype=input_dtype,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
use_bf16=use_bf16,
use_mxfp_quant=True,
)
output_dtype = gmm2_kwargs.pop("output_dtype")

if isinstance(weight, list) and len(weight) != 1:
raise ValueError(f"w2 must have a single tensor in MXFP path, but got {len(weight)}.")
if isinstance(weight_scale, list) and len(weight_scale) != 1:
raise ValueError(f"w2_scale must have a single tensor in MXFP path, but got {len(weight_scale)}.")
gmm2_weight = weight if isinstance(weight, list) else [weight]
gmm2_scale = weight_scale if isinstance(weight_scale, list) else [weight_scale]

return torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=gmm2_weight,
scale=gmm2_scale,
bias=bias,
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
**gmm2_kwargs,
)[0]


def get_device_adaptor():
def get_device_adaptor() -> type["BaseDeviceAdaptor"]:
ascend_device_type = get_ascend_device_type()
if ascend_device_type == AscendDeviceType.A5:
return A5DeviceAdaptor
return BaseDeviceAdaptor


DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor()
DeviceOperator: type["BaseDeviceAdaptor"] = get_device_adaptor()
Loading
Loading