Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tensorrt_llm/_torch/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,21 @@ def from_pretrained(cls,
128,
128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.group_size = block_size[0]
elif hf_quant_config.get(
"quant_method"
) == "compressed-tensors" and hf_quant_config.get(
"format") == "float-quantized":
# Llama4 FP8 ckpt
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
exclude_modules = hf_quant_config.get("ignore", None)
exclude_modules = [
s.removeprefix("language_model.") for s in exclude_modules
]
exclude_modules = [
s.replace("q_proj", "qkv_proj") for s in exclude_modules
]
quant_config.exclude_modules = exclude_modules
quant_config.use_meta_recipe = True

model_config = cls(pretrained_config=pretrained_config,
quant_config=quant_config,
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def __init__(
overridden_tp_size=1 if self.enable_attention_dp else None,
reduce_output=False)

weight_loading_mode = MoEWeightLoadingMode.META_FP8_RECIPE if model_config.quant_config.use_meta_recipe else MoEWeightLoadingMode.FUSED_GATE_UP_PROJ
self.experts = create_moe(
routing_method=Llama4RenormalizeMoeRoutingMethod(top_k),
num_experts=num_experts,
Expand All @@ -281,7 +282,7 @@ def __init__(
dtype=dtype,
reduce_results=
False, # In both low latency and max-throughput scenarios, FusedMoE needs not to do allreduce inside op.
weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
weight_loading_mode=weight_loading_mode,
model_config=model_config,
apply_router_weight_on_input=True,
layer_idx=layer_idx)
Expand Down
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _check_configs(self):
if not (self.quant_config.quant_mode.has_nvfp4()
| self.quant_config.quant_mode.has_fp8_block_scales()
| self.quant_config.quant_mode.has_fp8_qdq()
| self.quant_config.quant_mode.has_fp8_rowwise()
| self.quant_config.quant_mode.
is_int4_weight_only_per_group()):
raise ValueError(
Expand All @@ -145,7 +146,8 @@ def has_w4afp8(self):
def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
if self.quant_config.layer_quant_mode.has_fp8_qdq():
if self.quant_config.layer_quant_mode.has_fp8_qdq(
) or self.quant_config.quant_mode.has_fp8_rowwise():
return FP8QDQFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethod()
Expand Down Expand Up @@ -226,7 +228,7 @@ def forward_chunk(
weight_dtype = self.w3_w1_weight.dtype
x_sf = None
if self.has_any_quant:
if self.has_fp8_qdq:
if self.has_fp8_qdq or self.has_fp8_rowwise:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant)
elif self.has_deepseek_fp8_block_scales:
Expand Down
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
class MoEWeightLoadingMode(Enum):
VANILLA = 0
FUSED_GATE_UP_PROJ = 1
META_FP8_RECIPE = 2


class MoE(nn.Module):
Expand Down Expand Up @@ -111,6 +112,12 @@ def has_fp8_qdq(self):
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_qdq(
)

@property
def has_fp8_rowwise(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_rowwise(
)

@property
def has_deepseek_fp8_block_scales(self):
assert self._weights_created
Expand Down
14 changes: 14 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def load_expert_weights_to_dst(self, module: torch.nn.Module,
w1_weight = weights[f"{expert_id}.w1.weight"]
w3_weight = weights[f"{expert_id}.w3.weight"]
w2_weight = weights[f"{expert_id}.w2.weight"]
elif weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
w1_weight = weights[f"{expert_id}.gate_proj.weight"]
w3_weight = weights[f"{expert_id}.up_proj.weight"]
w2_weight = weights[f"{expert_id}.down_proj.weight"]
elif weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
0, 1)
Expand Down Expand Up @@ -364,6 +368,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
# HACK: dynamic quantization so no input scale - create empty tensors for testing
w1_input_scale = torch.tensor(1.0, dtype=torch.float32)
w3_input_scale = torch.tensor(1.0, dtype=torch.float32)
w2_input_scale = torch.tensor(1.0, dtype=torch.float32)
Comment on lines +371 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Replace hardcoded scale values with proper dynamic quantization implementation.

The current implementation uses hardcoded 1.0 values as a workaround for dynamic quantization. This approach could lead to incorrect quantization behavior and reduced model accuracy.

Consider implementing proper dynamic quantization scale computation instead:

 elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
-    # HACK: dynamic quantization so no input scale - create empty tensors for testing
-    w1_input_scale = torch.tensor(1.0, dtype=torch.float32)
-    w3_input_scale = torch.tensor(1.0, dtype=torch.float32)
-    w2_input_scale = torch.tensor(1.0, dtype=torch.float32)
+    # TODO: Implement proper dynamic quantization scale computation
+    # For now, compute scales dynamically or load from appropriate keys
+    w1_input_scale = self._compute_dynamic_input_scale(weights, expert_id, "gate_proj")
+    w3_input_scale = self._compute_dynamic_input_scale(weights, expert_id, "up_proj")  
+    w2_input_scale = self._compute_dynamic_input_scale(weights, expert_id, "down_proj")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 371 to
375, the code uses hardcoded scale values of 1.0 for dynamic quantization, which
is a placeholder and can cause incorrect quantization results. Replace these
hardcoded tensors with a proper dynamic quantization scale computation by
calculating the input scales based on the actual data or model parameters to
ensure accurate quantization behavior.

elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_input_scale = weights[f"gate_up_proj_input_scale"]
w3_input_scale = weights[f"gate_up_proj_input_scale"]
Expand Down Expand Up @@ -397,6 +406,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
# HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
Comment on lines +409 to +413
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Address the underlying shape mismatch instead of using hardcoded values.

The hardcoded weight scales could lead to suboptimal quantization. The comment indicates a shape mismatch with the "blockwise recipe" that should be addressed properly.

Consider addressing the root cause:

 elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
-    # HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
-    w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+    # Load and reshape weight scales to match expected format
+    w1_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "gate_proj")
+    w3_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "up_proj")
+    w2_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "down_proj")

Let me verify if there are existing implementations that handle dynamic quantization properly:


🏁 Script executed:

#!/bin/bash
# Search for dynamic quantization implementations and scale computation methods
rg -A 10 -B 5 "dynamic.*quantization|compute.*scale" --type py

Length of output: 45171


Replace hardcoded FP8 weight‐scale hack with proper loading and reshaping

The current META_FP8_RECIPE branch (lines 409–413 in
tensorrt_llm/_torch/modules/fused_moe/quantization.py) uses placeholder scalars for all three weight scales. This bypasses actual per-expert, blockwise scale data and will degrade quantization accuracy. Please implement a loader that:

  • Reads each expert’s stored weight scale (weights["{expert_id}.w1.weight_scale"], etc.)
  • Reshapes it to match the blockwise partitioning used by the fused MoE kernels
  • Registers or returns the correctly shaped tensor for w1_weight_scale, w3_weight_scale, and w2_weight_scale

For example, instead of:

-    # HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
-    w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+    # TODO: load and reshape the per-expert FP8 weight scales for blockwise META_FP8_RECIPE
+    raw_w1 = weights[f"{expert_id}.w1.weight_scale"]
+    raw_w3 = weights[f"{expert_id}.w3.weight_scale"]
+    raw_w2 = weights[f"{expert_id}.w2.weight_scale"]
+    w1_weight_scale = raw_w1.view(expected_block_shape).to(torch.float32)
+    w3_weight_scale = raw_w3.view(expected_block_shape).to(torch.float32)
+    w2_weight_scale = raw_w2.view(expected_block_shape).to(torch.float32)

• File: tensorrt_llm/_torch/modules/fused_moe/quantization.py
• Lines: ~409–413

Implement and test this loader against your blockwise FP8 metadata to ensure correct shapes and quantization fidelity.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 409 to
413, replace the hardcoded scalar tensors used as placeholders for
w1_weight_scale, w2_weight_scale, and w3_weight_scale with actual loading logic.
Implement code to read each expert's stored weight scale from weights using keys
like "{expert_id}.w1.weight_scale", reshape these tensors to match the blockwise
partitioning expected by the fused MoE kernels, and assign or return these
properly shaped tensors instead of the placeholders. Ensure the loader correctly
handles all experts and weight scales to maintain quantization accuracy.

elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
Expand Down
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]):

scale_name = self._get_scale_name(weights)
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
module.tp_rank,
module.tp_mode).squeeze()
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
Expand Down Expand Up @@ -533,7 +534,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat((left_scale, right_scale))
fused_scale = torch.cat((left_scale, right_scale)).squeeze()
copy_weight(module.weight_scale, fused_scale)


Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/llmapi/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,12 @@ def _update_from_hf_quant_config(self) -> bool:
"weight_block_size"):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
elif hf_quant_config.get(
"quant_method"
) == "compressed-tensors" and hf_quant_config.get(
"format") == "float-quantized":
# Llama4 FP8 ckpt
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")
Expand Down