Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def __init__(
self.out_dtype = moe_config.in_dtype
self.num_local_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# FC2 input scale tensor bound in process_weights_after_loading: the
# calibrated a2_gscale (overwritten in place with 1.0) for static-quant
# checkpoints, or a synthesized uniform-1.0 tensor for W4A16 checkpoints
# that lack one. Holding it on the instance keeps apply() alloc-free.
self._fc2_input_scale: torch.Tensor | None = None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Normalise block scales to absorb the per-expert weight global scale
Expand Down Expand Up @@ -86,6 +91,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# its own per-block dynamic scale.
if self.a2_gscale is not None:
self.a2_gscale.fill_(1.0)
self._fc2_input_scale = self.a2_gscale
else:
# W4A16 NVFP4 checkpoints have no calibrated a2_gscale; b12x
# performs dynamic per-block FC2-input quantization, so a uniform
# 1.0 scale per expert is equivalent to the bake-in above for
# static-quant checkpoints. Allocate once here so apply() stays
# alloc-free.
self._fc2_input_scale = torch.ones(
self.num_local_experts,
device=layer.w13_weight.device,
dtype=torch.float32,
)

# Precompute MMA-layout views of the weight scale factors once here
# rather than recomputing on every forward pass.
Expand Down Expand Up @@ -131,7 +148,13 @@ def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
# b12x performs in-kernel BF16->FP4 activation quant, so W4A16
# NVFP4 checkpoints (activation_key=None, e.g. mixed-precision
# compressed-tensors layouts) are runtime-compatible.
return (weight_key, activation_key) in (
(kNvfp4Static, kNvfp4Dynamic),
(kNvfp4Static, None),
)
Comment thread
ECMGit marked this conversation as resolved.

@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
Expand Down Expand Up @@ -198,8 +221,8 @@ def apply(
assert self.g1_alphas is not None and self.g2_alphas is not None, (
"g1_alphas and g2_alphas must not be None for FlashInferB12xExperts"
)
assert self.a2_gscale is not None, (
"a2_gscale must not be None for FlashInferB12xExperts"
assert self._fc2_input_scale is not None, (
"_fc2_input_scale must be set by process_weights_after_loading"
)

top_k = topk_ids.shape[1]
Expand All @@ -211,7 +234,7 @@ def apply(
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
fc2_input_scale=self._fc2_input_scale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
Expand Down
Loading