Initial Commit GPT-OSS #485
Conversation
Signed-off-by: Himangshu Lahkar <hlahkar@habana.ai>
```python
if self.bias is not None:
    w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
    w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
    return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states,
```
Test fails with:

> The underlying op of 'hpu.mixture_of_experts' has no overload name 'bias_fused_weights'. Did you mean: 'fp8_fused_weights'

Please fix.
The CI is on 1.22.0, but this change needs the 1.23.0 software stack; that is why it is failing. We can merge this only after CI moves to the 1.23.0 release.
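Until CI catches up, a capability probe could gate the new call path. A minimal sketch, assuming the HPU op packet's overload list can be introspected via `torch.ops` (the plugin's actual gating may key off the release version instead):

```python
import torch

def has_bias_fused_moe() -> bool:
    # Probe the registered overloads of the HPU MoE op packet; the
    # 'bias_fused_weights' overload only exists on 1.23.0+ builds.
    try:
        return "bias_fused_weights" in torch.ops.hpu.mixture_of_experts.overloads()
    except (AttributeError, RuntimeError):
        # Op packet not registered at all on this build.
        return False
```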
Pull Request Overview
This PR enables GPT-OSS model support with two main features: attention sinks for improved context handling and bias support in Mixture of Experts (MoE) layers.
Key Changes:
- Added sink attention mechanism to handle long-context scenarios across naive, FSDPA, and flat attention implementations (a minimal sketch of the mechanism follows this list)
- Implemented bias support in MoE operations for models requiring biased expert computations
- Added model-specific routing logic for GPT-OSS in the MoE forward pass
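For reference, the sink mechanism amounts to adding one learned per-head logit to each softmax row, so the sink absorbs probability mass without contributing a value vector. A minimal dense sketch, illustrative only; the PR implements this inside the naive, FSDPA, and flat HPU attention paths:

```python
import torch

def sink_attention(q, k, v, sinks, scale):
    # q: [heads, q_len, d]; k, v: [heads, kv_len, d]; sinks: [heads]
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale   # [h, q, kv]
    sink = sinks.view(-1, 1, 1).expand(-1, q.shape[-2], 1)  # [h, q, 1]
    # The sink logit joins the softmax denominator, then its column is dropped.
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return torch.matmul(probs[..., :-1], v)
```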
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| vllm_gaudi/ops/hpu_fused_moe.py | Added bias handling in MoE layers and GPT-OSS specific router weight processing |
| vllm_gaudi/extension/utils.py | Extended FSDPA forward method to accept sinks parameter |
| vllm_gaudi/extension/ops.py | Implemented sink attention logic across multiple attention implementations and added bias support to MoE operations |
| vllm_gaudi/attention/backends/hpu_attn.py | Added sinks parameter to attention implementations with validation and dtype conversion |
| tests/unit_tests/sinks/test_gpt_oss.py | Added integration test for GPT-OSS model with expected outputs |
Comments suppressed due to low confidence (2)

vllm_gaudi/attention/backends/hpu_attn.py:
- Missing space after '#' in comment. Should be '# causal' for proper comment formatting.
- Inconsistent TODO format: should be 'TODO:' with a colon instead of a dash.
```python
w12=w1_list,
w3=w2_list,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=self.experts_min,
experts_max=self.experts_max)
```
Incorrect weight lists passed to MoE operation. Should use the sliced lists `w1_list_slice` and `w2_list_slice` instead of the full lists `w1_list` and `w2_list` to match the expert range being processed.
Suggested change:

```python
w12=w1_list_slice,
w3=w2_list_slice,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=min_expert,
experts_max=max_expert)
```
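For context, a hedged sketch of the slicing pattern the suggestion assumes: each chunk computes its own expert range, and every per-expert list must be sliced with that same range (the function shape and `num_slices` are illustrative, not the plugin's exact code):

```python
def slice_experts(w1_list, w2_list, num_slices):
    """Yield per-slice weight lists plus the expert range each covers."""
    num_experts = len(w1_list)
    n_expert_slice = (num_experts + num_slices - 1) // num_slices
    for s in range(num_slices):
        min_expert = s * n_expert_slice
        max_expert = min((s + 1) * n_expert_slice, num_experts) - 1
        # The same range that slices the lists must be passed to the op
        # as experts_min / experts_max.
        yield (w1_list[min_expert:max_expert + 1],
               w2_list[min_expert:max_expert + 1],
               min_expert, max_expert)
```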
```python
experts_min=self.experts_min,
experts_max=self.experts_max)
```
Incorrect expert range parameters. Should use `min_expert` and `max_expert` (computed for the current slice) instead of `self.experts_min` and `self.experts_max` to correctly process the expert slice.
Suggested change:

```python
experts_min=min_expert,
experts_max=max_expert)
```
```python
# TODO - change 128 to proper window size
window_size = (
    128,
```
Magic number 128 used for window size. Consider defining this as a named constant or deriving it from `self.sliding_window` as indicated by the TODO comment.
Suggested change:

```python
# Use self.sliding_window for window size instead of hardcoded 128
window_size = (
    self.sliding_window,
```
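A hedged sketch of deriving the window from the layer's configured sliding window rather than a constant, using the common (left, right) convention where -1 means unbounded (the helper and its fallback are assumptions, not the plugin's exact code):

```python
def resolve_window_size(sliding_window):
    # Attend `sliding_window` tokens back and none forward (causal) when a
    # window is configured; (-1, -1) conventionally means unbounded.
    if sliding_window is not None:
        return (sliding_window, 0)
    return (-1, -1)
```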
```python
    tensor_parallel_size=4,
)
generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
assert generated_texts == expected_output
```
The assertion compares the full generated list with the expected output, but the function returns a list and only the first element is validated earlier. This will fail unless `generated_texts` contains exactly one element matching `expected_output[0]`. Consider `assert generated_texts[0] == expected_output[0]`, or `assert generated_texts == expected_output` after validating the list length.
Suggested change:

```python
assert len(generated_texts) == len(expected_output)
assert generated_texts[0] == expected_output[0]
```
```python
attn_sink = attn_sink.exp()
if attn_sink.dtype == torch.float32:
    attn_sink = attn_sink.to(value.dtype)
#TODO: Removing this .sum and using attn_sink directly
```
Missing space in TODO comment: should be '# TODO:' with a space after '#' for consistency.
Suggested change:

```python
# TODO: Removing this .sum and using attn_sink directly
```
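For what the TODO implies mathematically: the `.sum` feeds the softmax denominator, and folding a sink into an already-computed softmax is a pure per-row rescale. A minimal sketch, assuming the row log-sum-exp is kept around (not the HPU kernel's actual code):

```python
import torch

def fold_in_sink(probs, logsumexp, sinks):
    # probs:     [heads, q_len, kv_len], rows of softmax(scores)
    # logsumexp: [heads, q_len], log of each row's softmax denominator
    # sinks:     [heads], learned sink logits
    # New row mass is denom / (denom + exp(sink)) = sigmoid(lse - sink).
    scale = torch.sigmoid(logsumexp - sinks.unsqueeze(-1))
    return probs * scale.unsqueeze(-1)
```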
```python
attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
# TODO - change 128 to proper window size
```
Inconsistent TODO format: should be 'TODO:' with a colon instead of a dash for consistency with project conventions.
Suggested change:

```python
# TODO: change 128 to proper window size
```
Signed-off-by: Himangshu Lahkar <hlahkar@habana.ai>
Signed-off-by: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com>
Tracking this with #771, as there are a lot of changes due to the latest vLLM plugin.
This enables GPT-OSS with naive attention. Features enabled: attention sinks and bias support in MoE layers.