Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,10 @@ def apply_without_routing_weights(

class NPUW4A8Int4DynamicMoEMethod(FusedMoEMethodBase):

def __init__(self) -> None:
def __init__(self, activation_use_clip: bool) -> None:
    """Initialize the W4A8 dynamic MoE quantization method.

    Args:
        activation_use_clip: whether the quantization scheme clamps
            activations; selects the clip-specific parameter set and
            post-load weight processing path.
    """
    self.activation_use_clip = activation_use_clip
    # Defaults; expected to be overwritten later (e.g. during weight creation).
    self.tp_size = 1
    self.group_size = 0

def create_weights(
self,
Expand Down Expand Up @@ -366,17 +367,21 @@ def create_weights(
set_weight_attrs(w2_weight, extra_weight_attrs)

# >> scale
weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=weight_scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)

w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
Expand All @@ -400,6 +405,77 @@ def create_weights(
set_weight_attrs(w2_weight_offset, extra_weight_attrs)

# >>> special param for w4a8
if self.activation_use_clip:
self._init_activation_clip_params(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
extra_weight_attrs,
)
else:
self._init_extra_scale_params(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
extra_weight_attrs,
)

def _init_activation_clip_params(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    extra_weight_attrs: dict,
) -> None:
    """
    Initializes bias and alpha parameters for quantization schemes that use activation clipping.

    This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to
    shift and scale the activations or outputs to compensate for the precision loss
    introduced by clamping activations.
    """
    # (parameter name, full tensor shape) — all initialized to ones, fp32,
    # frozen (requires_grad=False), and tagged with the shared loader attrs.
    specs = (
        ("w13_bias", (num_experts, 2 * intermediate_size_per_partition)),
        ("w2_bias", (num_experts, hidden_size)),
        ("w2_alpha", (num_experts,)),
    )
    for param_name, shape in specs:
        param = torch.nn.Parameter(
            torch.ones(*shape, dtype=torch.float), requires_grad=False
        )
        layer.register_parameter(param_name, param)
        set_weight_attrs(param, extra_weight_attrs)

def _init_extra_scale_params(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
extra_weight_attrs: dict,
) -> None:
"""
Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping.

This method registers the following parameters:
1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`.
2. Secondary Quantization Params (initialized only for grouped quantization):
`w13_weight_scale_second`, `w13_weight_offset_second`,
`w2_weight_scale_second`, and `w2_weight_offset_second`.
"""
if not self.is_per_channel_weight:
w13_weight_scale_second = torch.nn.Parameter(
torch.empty(
Expand All @@ -412,6 +488,7 @@ def create_weights(
)
layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second)
set_weight_attrs(w13_weight_scale_second, extra_weight_attrs)

w13_weight_offset_second = torch.nn.Parameter(
torch.empty(
num_experts,
Expand Down Expand Up @@ -515,13 +592,25 @@ def pack_to_int32(self, weight: torch.Tensor):
return weight.view(torch.int32).contiguous()

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Finalize MoE weights after checkpoint loading.

    First runs the mode-specific scale/bias post-processing (clip vs.
    no-clip), then for each expert weight tensor: swaps dims 1 and 2,
    casts it to the NPU weight format, and packs the values into int32
    words via `pack_to_int32`.
    """
    if self.activation_use_clip:
        self._process_weights_with_clip(layer)
    else:
        self._process_weights_without_clip(layer)

    # w13 and w2 undergo the identical layout transformation and are
    # independent of each other, so process them uniformly.
    for attr_name in ("w13_weight", "w2_weight"):
        weight = getattr(layer, attr_name)
        new_param = torch.nn.Parameter(
            weight.data.transpose(1, 2).contiguous(), requires_grad=False
        )
        new_param.data = self.pack_to_int32(npu_format_cast(new_param.data))
        setattr(layer, attr_name, new_param)

def _process_weights_without_clip(self, layer: torch.nn.Module) -> None:
w13_weight_scale_second = (
layer.w13_weight_scale_second.data
if hasattr(layer, "w13_weight_scale_second")
Expand All @@ -547,10 +636,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

self.update_bias(layer, w13_bias, w2_bias)

layer.w13_weight.data = npu_format_cast(layer.w13_weight.data)
layer.w2_weight.data = npu_format_cast(layer.w2_weight.data)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
def _process_weights_with_clip(self, layer: torch.nn.Module) -> None:
    """Post-process scales and biases for the activation-clip path.

    Reshapes each weight scale by dropping its trailing singleton dim and
    inserting a singleton dim at axis 1 (e.g. (E, n, 1) -> (E, 1, n), per
    the shapes allocated in `create_weights`), then exposes the loaded
    bias parameters under the `*_scale_bias` names.
    """

    def _reshape_scale(scale_param: torch.nn.Parameter) -> torch.nn.Parameter:
        # Move the singleton dim from the last axis to axis 1.
        reshaped = scale_param.data.squeeze(-1).contiguous().unsqueeze(1)
        return torch.nn.Parameter(reshaped, requires_grad=False)

    layer.w13_weight_scale = _reshape_scale(layer.w13_weight_scale)
    layer.w2_weight_scale = _reshape_scale(layer.w2_weight_scale)

    # Alias the clip biases to the names the runner consumes.
    layer.w13_scale_bias = layer.w13_bias
    layer.w2_scale_bias = layer.w2_bias

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def __init__(self, quant_config: Dict[str, Any] = {}):
self.packed_modules_mapping = (
packed_modules_mapping if packed_modules_mapping is not None else {}
)
self.activation_use_clip = (
self.quant_description.get("config_groups", {})
.get("group_1", {})
.get("activation_use_clip", False)
)
self.target_scheme_map = (
CompressedTensorsConfig._quantization_scheme_map_from_config(
config=quant_config
Expand Down Expand Up @@ -180,7 +185,9 @@ def get_quant_method(
if (
self.is_moe_w4_dynamic and self.is_moe_input_quant is not None
) or is_moe_w4a8_dynamic:
return NPUW4A8Int4DynamicMoEMethod()
return NPUW4A8Int4DynamicMoEMethod(
activation_use_clip=self.activation_use_clip
)
elif self.is_moe_w4_dynamic and self.is_moe_input_quant is None:
return NPUW4A16Int4DynamicMoEMethod(self)
else:
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,18 @@ def _weight_loader_impl(
)
return

if (
"bias" in weight_name
and self.quant_config.quant_description["quant_method"] == "modelslim"
):
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)

def weight_loader_fused(
self,
param: torch.nn.Parameter,
Expand Down
Loading