[feat] Add support to load fp8 Meta Llama4 weights #6268
base: main
Changes from all commits
```diff
@@ -100,6 +100,10 @@ def load_expert_weights_to_dst(self, module: torch.nn.Module,
             w1_weight = weights[f"{expert_id}.w1.weight"]
             w3_weight = weights[f"{expert_id}.w3.weight"]
             w2_weight = weights[f"{expert_id}.w2.weight"]
+        elif weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
+            w1_weight = weights[f"{expert_id}.gate_proj.weight"]
+            w3_weight = weights[f"{expert_id}.up_proj.weight"]
+            w2_weight = weights[f"{expert_id}.down_proj.weight"]
         elif weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
             w1_w3_weight = weights["gate_up_proj"][expert_id].transpose(
                 0, 1)
```
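For reference, the new branch maps Meta's projection names onto the loader's internal slots: gate_proj → w1, up_proj → w3, down_proj → w2. A tiny, hypothetical helper showing that remapping in isolation (not part of the PR; the function and dict below are illustrative only):

```python
# Illustrative only: remap Meta-style per-expert keys onto the internal
# w1/w3/w2 names used by the MoE weight loader.
_META_TO_INTERNAL = {
    "gate_proj": "w1",  # gating projection
    "up_proj": "w3",    # up projection
    "down_proj": "w2",  # down projection
}

def remap_expert_keys(weights: dict) -> dict:
    """Rename '<expert_id>.<meta_name>.weight' keys to '<expert_id>.<internal_name>.weight'."""
    remapped = {}
    for key, tensor in weights.items():
        parts = key.split(".")
        if len(parts) == 3 and parts[1] in _META_TO_INTERNAL:
            expert_id, proj, suffix = parts
            key = f"{expert_id}.{_META_TO_INTERNAL[proj]}.{suffix}"
        remapped[key] = tensor
    return remapped

# Example: {'0.gate_proj.weight': t} -> {'0.w1.weight': t}
```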
```diff
@@ -364,6 +368,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
             w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
             w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
             w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+        elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
+            # HACK: dynamic quantization so no input scale - create empty tensors for testing
+            w1_input_scale = torch.tensor(1.0, dtype=torch.float32)
+            w3_input_scale = torch.tensor(1.0, dtype=torch.float32)
+            w2_input_scale = torch.tensor(1.0, dtype=torch.float32)
         elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
             w1_input_scale = weights[f"gate_up_proj_input_scale"]
             w3_input_scale = weights[f"gate_up_proj_input_scale"]
```
```diff
@@ -397,6 +406,11 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
             w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
             w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
             w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+        elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
+            # HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
+            w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+            w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+            w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
```
Comment on lines +409 to +413

Address the underlying shape mismatch instead of using hardcoded values.

The hardcoded weight scales could lead to suboptimal quantization. The comment indicates a shape mismatch with the "blockwise recipe" that should be addressed properly.

Consider addressing the root cause:

```diff
 elif module.weight_loading_mode == MoEWeightLoadingMode.META_FP8_RECIPE:
-    # HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
-    w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+    # Load and reshape weight scales to match expected format
+    w1_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "gate_proj")
+    w3_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "up_proj")
+    w2_weight_scale = self._load_and_reshape_weight_scale(weights, expert_id, "down_proj")
```

Let me verify if there are existing implementations that handle dynamic quantization properly:

🏁 Script executed:

```bash
#!/bin/bash
# Search for dynamic quantization implementations and scale computation methods
rg -A 10 -B 5 "dynamic.*quantization|compute.*scale" --type py
```

Length of output: 45171

Replace hardcoded FP8 weight-scale hack with proper loading and reshaping.

The current META_FP8_RECIPE branch (lines 409–413) hardcodes every weight scale to 1.0, which discards the blockwise FP8 scales stored in the checkpoint. Load the per-expert scales and reshape them to the layout the kernel expects. For example, instead of:

```diff
-    # HACK: weight scale is not the right shape in blockwise recipe - create empty tensors for testing
-    w1_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w3_weight_scale = torch.tensor(1.0, dtype=torch.float32)
-    w2_weight_scale = torch.tensor(1.0, dtype=torch.float32)
+    # TODO: load and reshape the per-expert FP8 weight scales for blockwise META_FP8_RECIPE
+    raw_w1 = weights[f"{expert_id}.w1.weight_scale"]
+    raw_w3 = weights[f"{expert_id}.w3.weight_scale"]
+    raw_w2 = weights[f"{expert_id}.w2.weight_scale"]
+    w1_weight_scale = raw_w1.view(expected_block_shape).to(torch.float32)
+    w3_weight_scale = raw_w3.view(expected_block_shape).to(torch.float32)
+    w2_weight_scale = raw_w2.view(expected_block_shape).to(torch.float32)
```

Implement and test this loader against your blockwise FP8 metadata to ensure correct shapes and quantization fidelity.
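If it helps, here is a rough, self-contained sketch of what such a loader could look like. The scale key pattern, the assumption that scales are stored one-per-block in a 2-D tensor, and the idea that the kernel wants a dense per-element scale map are all assumptions on my part and would need to be verified against the actual Meta FP8 checkpoint and the fused MoE kernel interface.

```python
import torch

def load_blockwise_weight_scale(weights: dict,
                                expert_id: int,
                                proj_name: str,
                                expected_shape: torch.Size) -> torch.Tensor:
    """Sketch: fetch a per-expert blockwise FP8 weight scale and coerce it
    into the shape the MoE kernel expects.

    Assumptions (not verified against the real checkpoint):
      * the scale is stored under '<expert_id>.<proj_name>.weight_scale'
      * it is a 2-D tensor holding one scale per weight block
      * the kernel wants a dense tensor of `expected_shape`
    """
    raw = weights[f"{expert_id}.{proj_name}.weight_scale"].to(torch.float32)
    if raw.shape == expected_shape:
        return raw
    # Expand each block scale to cover its block, then crop to the target shape.
    reps = [-(-e // r) for e, r in zip(expected_shape, raw.shape)]  # ceil division
    expanded = raw.repeat_interleave(reps[0], dim=0).repeat_interleave(reps[1], dim=1)
    return expanded[:expected_shape[0], :expected_shape[1]].contiguous()
```

The ceil-division repeat is only there to cover edge blocks when the weight dimensions are not exact multiples of the block size.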
The rest of the hunk (the existing FUSED_GATE_UP_PROJ branch) is unchanged:

```diff
         elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
             w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
             w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
```
Replace hardcoded scale values with proper dynamic quantization implementation.
The current implementation uses hardcoded 1.0 values as a workaround for dynamic quantization. This approach could lead to incorrect quantization behavior and reduced model accuracy.
Consider implementing proper dynamic quantization scale computation instead:
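To sketch what that could mean in practice: with dynamic quantization there is no checkpoint input scale, so the per-tensor FP8 scale has to be derived from each activation's observed range at runtime. The snippet below only illustrates the standard amax-based formula; the function names are hypothetical, and whether the fused MoE path can consume runtime scales this way would need to be checked against the kernel interface.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def dynamic_fp8_input_scale(activation: torch.Tensor) -> torch.Tensor:
    """Sketch: per-tensor dynamic FP8 scale computed from the activation's amax.

    The scale maps the observed dynamic range onto the FP8 representable range:
        scale = amax(|x|) / FP8_E4M3_MAX
    """
    amax = activation.abs().amax().to(torch.float32)
    # Guard against all-zero activations so we never divide by zero.
    return torch.clamp(amax, min=1e-12) / FP8_E4M3_MAX

def quantize_fp8(activation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize an activation tensor to float8_e4m3fn with a dynamic scale."""
    scale = dynamic_fp8_input_scale(activation)
    q = (activation / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale
```

In that setup the 1.0 placeholders would only ever act as load-time stand-ins, with the real scales produced per forward pass.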