Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight

from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1


if _is_cuda:
from sgl_kernel import fused_marlin_moe
Expand Down Expand Up @@ -292,7 +292,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No
max_w13_scales, requires_grad=False
)

if _use_aiter:
if self.weight_quant.strategy == QuantizationStrategy.CHANNEL and _use_aiter:
with torch.no_grad():
# Pre-shuffle weights
layer.w13_weight = torch.nn.Parameter(
Expand Down Expand Up @@ -325,23 +325,33 @@ def apply(

moe_runner_config = self.moe_runner_config

if (
_use_aiter
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and moe_runner_config.apply_router_weight_on_input
):
if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
assert not moe_runner_config.no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
output = rocm_fused_experts_tkw1(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=moe_runner_config.activation,
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL,
if moe_runner_config.apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
x = x * topk_weights.to(x.dtype)
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)
output = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
Comment on lines +349 to +353
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation for selecting the activation function defaults to `ActivationType.Gelu` for any activation string other than `'silu'`. This could lead to unexpected behavior or silent errors if an unsupported activation function is provided. It would be more robust to explicitly handle only the supported activations and raise an error for any others.

For example:

```python
if moe_runner_config.activation == "silu":
    activation = ActivationType.Silu
elif moe_runner_config.activation == "gelu":
    activation = ActivationType.Gelu
else:
    raise ValueError(f"Unsupported activation: {moe_runner_config.activation}")
```

This change would make the code safer and easier to debug.

quant_type=QuantType.per_Token,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
Expand Down
Loading