Conversation
👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review. Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.
The quality checks have failed. Please run …
Code Review
This pull request refactors MoE calibration by replacing model-specific modules with a generic linearization framework that unfuses expert weights into standard nn.Linear layers. Feedback identifies critical bugs, such as missing imports in the GPT-OSS module and incorrect handling of 3D input tensors in the LinearExperts forward pass. Further improvements were suggested regarding the fragility of using source code inspection for module detection and the efficiency of the gated MLP implementation.
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    final_hidden_states = torch.zeros_like(hidden_states)
    num_experts = len(self)

    # create tokens mask
    with torch.no_grad():
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

    for expert_idx in range(num_experts):
        # select tokens for this expert
        top_k_pos, token_indices = torch.where(expert_mask[expert_idx])
        if token_indices.numel() == 0:
            continue

        # apply expert, maybe pass all tokens to the expert
        expert = self[expert_idx]
        if context.CALIBRATE_ALL_EXPERTS:
            expert_output = expert(hidden_states)[token_indices]
        else:
            expert_output = expert(hidden_states[token_indices])

        # apply weighting to outputs
        expert_weights = top_k_weights[token_indices, top_k_pos, None]
        weighted_output = expert_output * expert_weights

        # accumulate the selected tokens
        final_hidden_states.index_add_(0, token_indices, weighted_output)

    return final_hidden_states
```
The forward method of LinearExperts needs to handle 3D hidden_states (e.g., [batch, sequence, hidden]) by flattening them before processing. Otherwise, index_add_(0, token_indices, ...) will incorrectly index into the batch dimension instead of the token dimension, leading to incorrect results or out-of-bounds errors. Additionally, an explicit cast to the destination dtype is recommended for weighted_output to ensure compatibility with index_add_ when using mixed precision (e.g., float32 weights with bfloat16 states).
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    orig_shape = hidden_states.shape
    # Flatten to (total_tokens, hidden_dim)
    hidden_states = hidden_states.view(-1, orig_shape[-1])
    top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
    top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
    final_hidden_states = torch.zeros_like(hidden_states)
    num_experts = len(self)

    # create tokens mask
    with torch.no_grad():
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

    for expert_idx in range(num_experts):
        # select tokens for this expert
        top_k_pos, token_indices = torch.where(expert_mask[expert_idx])
        if token_indices.numel() == 0:
            continue

        # apply expert, maybe pass all tokens to the expert
        expert = self[expert_idx]
        if context.CALIBRATE_ALL_EXPERTS:
            expert_output = expert(hidden_states)[token_indices]
        else:
            expert_output = expert(hidden_states[token_indices])

        # apply weighting to outputs
        expert_weights = top_k_weights[token_indices, top_k_pos, None]
        weighted_output = expert_output * expert_weights

        # accumulate the selected tokens
        final_hidden_states.index_add_(
            0, token_indices, weighted_output.to(final_hidden_states.dtype)
        )

    return final_hidden_states.view(orig_shape)
```
Hidden states are expected to already be flattened; see:
https://github.com/huggingface/transformers/blob/a66638d854ae536e0ca31e8bcfa480adfaf58284/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py#L141
```python
def _is_moe_experts_module(module) -> bool:
    """Detect modules whose class is decorated with
    ``@use_experts_implementation`` by inspecting the class source AST."""
    try:
        source = inspect.getsource(type(module))
        tree = ast.parse(source)
    except (OSError, TypeError):
        return False

    for node in ast.iter_child_nodes(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        for decorator in node.decorator_list:
            if isinstance(decorator, ast.Name):
                name = decorator.id
            elif isinstance(decorator, ast.Call) and isinstance(
                decorator.func, ast.Name
            ):
                name = decorator.func.id
            else:
                continue
            if name == "use_experts_implementation":
                return True

    return False
```
Using inspect.getsource and ast.parse to detect MoE modules is fragile and potentially slow. It will fail if the source code is unavailable (e.g., in some deployment environments) and adds overhead for every module in the model. A more robust approach would be to check for specific attributes (like gate_up_proj and down_proj) or use a more direct way to identify these modules if possible.
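For illustration, a minimal sketch of the attribute-based check suggested above (the attribute names and 3D layout are assumptions based on common fused-expert implementations, not taken from the PR):

```python
import torch
import torch.nn as nn


def _is_moe_experts_module(module: nn.Module) -> bool:
    # Hypothetical duck-typed check: fused expert modules typically store
    # 3D (num_experts, in_features, out_features) parameters named
    # gate_up_proj / down_proj instead of child nn.Linear layers.
    gate_up = getattr(module, "gate_up_proj", None)
    down = getattr(module, "down_proj", None)
    return (
        isinstance(gate_up, torch.Tensor) and gate_up.dim() == 3
        and isinstance(down, torch.Tensor) and down.dim() == 3
    )
```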
Yep, will think more on this. I think this is the most robust solution, but it should probably be LRU-cached against the module class.
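A rough sketch of that caching, assuming the AST inspection is factored into a per-class helper (the helper name is illustrative):

```python
import ast
import inspect
import textwrap
from functools import lru_cache


@lru_cache(maxsize=None)
def _class_uses_experts_implementation(module_cls: type) -> bool:
    # run the (slow) source inspection once per class; results are memoized
    try:
        tree = ast.parse(textwrap.dedent(inspect.getsource(module_cls)))
    except (OSError, TypeError, SyntaxError):
        return False
    for node in ast.iter_child_nodes(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        for decorator in node.decorator_list:
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            if isinstance(target, ast.Name) and target.id == "use_experts_implementation":
                return True
    return False


def _is_moe_experts_module(module) -> bool:
    return _class_uses_experts_implementation(type(module))
```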
```python
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    return self.down_proj(
        self._apply_gate(
            torch.cat(
                [self.gate_proj(hidden_states), self.up_proj(hidden_states)], dim=-1
            )
        )
    )
```
The use of torch.cat followed by _apply_gate (which typically performs a chunk operation) is inefficient. Since _apply_gate is currently restricted to _default_apply_gate (which just splits the input and multiplies), you could optimize this by applying the gate logic directly to the separate projection outputs, avoiding the concatenation and subsequent chunking.
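A sketch of that optimization, assuming `_default_apply_gate` chunks its input into gate/up halves and computes `act_fn(gate) * up` (the `act_fn` attribute is an assumption here):

```python
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # project each branch once and gate directly, skipping the
    # torch.cat(...) -> chunk(...) round trip
    gate = self.gate_proj(hidden_states)
    up = self.up_proj(hidden_states)
    return self.down_proj(self.act_fn(gate) * up)
```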
```python
def linearize_moe_model(model: PreTrainedModel):
```
In the base case, we need to do a conversion. However, we can optimize this on a per-model basis by writing our own weight converters that are used directly during loading.
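As a rough illustration of the base-case conversion (shapes, names, and the fused-weight layout are assumptions; an actual converter would follow each model's specific layout):

```python
import torch
import torch.nn as nn


def unfuse_expert_weights(fused: torch.Tensor) -> nn.ModuleList:
    # fused is assumed to be (num_experts, in_features, out_features);
    # nn.Linear stores weight as (out_features, in_features), hence the .T
    num_experts, in_features, out_features = fused.shape
    experts = nn.ModuleList()
    for idx in range(num_experts):
        linear = nn.Linear(
            in_features, out_features, bias=False,
            dtype=fused.dtype, device=fused.device,
        )
        linear.weight.data.copy_(fused[idx].T)
        experts.append(linear)
    return experts
```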
The quality checks have failed. Please run …
Thanks for working on this @kylesayrs. The IBM Spyre stack for vLLM depends on llm-compressor because of this library: https://github.com/foundation-model-stack/fms-model-optimizer. Currently we're unable to upgrade to transformers 5 due to this. Do you need any help?
This pull request has merge conflicts that must be resolved before it can be merged.