[FIX_FOR_VLLM_CUSTOM=f976e3b98ba45677a2213673a442c6cbff141e8e] Fix upstream regressions in attention, FP8, offloading and platform #1338
Changes from all commits: adb7220, c2ae5a6, b4cc45a, 0456f22, ae16ecb
New test file (@@ -0,0 +1,24 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config


@pytest.fixture
def default_vllm_config():
    """VllmConfig with a minimal model_config stub.

    Upstream Fp8LinearMethod.__init__ accesses model_config.dtype.
    We provide a SimpleNamespace with the attributes required by quantization
    methods so that ops-level unit tests can run without a full model setup.
    """
    vllm_config = VllmConfig()
    vllm_config.model_config = SimpleNamespace(dtype=torch.bfloat16, is_moe=False, hf_config=None, quantization=None)

    with set_current_vllm_config(vllm_config):
        yield
```
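A minimal sketch of how an ops-level test could consume this fixture; the test name is hypothetical, and `get_current_vllm_config` is assumed to be the accessor that quantization methods read internally:

```python
import torch

from vllm.config import get_current_vllm_config


def test_quant_method_sees_model_dtype(default_vllm_config):
    # The fixture installs a VllmConfig whose model_config stub exposes
    # .dtype, so code that reads it (as upstream Fp8LinearMethod.__init__
    # does) can run without constructing a full model.
    cfg = get_current_vllm_config()
    assert cfg.model_config.dtype is torch.bfloat16
```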
Modified file (MLA attention forward path):

```diff
@@ -69,19 +69,7 @@ def forward(
             #    self.kv_cache_dtype,
             #    self._k_scale,
             #)
-            if self.attn_backend.accept_output_buffer:
-                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
-                self.forward_impl(
-                    q,
-                    kv_c_normed,
-                    k_pe,
-                    self_kv_cache,
-                    attn_metadata,
-                    output=output,
-                )
-                return output
-            else:
-                return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata)
+            return self.forward_impl(q, kv_c_normed, k_pe, self_kv_cache, attn_metadata)
         else:
             kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
                 kv_c_normed,
@@ -90,25 +78,16 @@ def forward(
                 self.kv_cache_dtype,
                 self._k_scale,
             )
-            if self.attn_backend.accept_output_buffer:
-                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
-                torch.ops.vllm.unified_mla_attention_with_output(
-                    q,
-                    kv_c_normed,
-                    k_pe,
-                    output,
-                    self.layer_name,
-                    kv_cache_dummy_dep=kv_cache_dummy_dep,
-                )
-                return output
-            else:
-                return torch.ops.vllm.unified_mla_attention(
-                    q,
-                    kv_c_normed,
-                    k_pe,
-                    self.layer_name,
-                    kv_cache_dummy_dep=kv_cache_dummy_dep,
-                )
+            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
+            torch.ops.vllm.unified_mla_attention_with_output(
+                q,
+                kv_c_normed,
+                k_pe,
+                output,
+                self.layer_name,
+                kv_cache_dummy_dep=kv_cache_dummy_dep,
+            )
+            return output

     def forward_impl(
         self,
@@ -121,9 +100,6 @@ def forward_impl(
         output_scale: torch.Tensor | None = None,
         output_block_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
```
A review comment proposes guarding the now-unused output arguments in forward_impl (suggested change):

```diff
     ) -> torch.Tensor:
+        if (output is not None or output_scale is not None
+                or output_block_scale is not None):
+            raise NotImplementedError(
+                "HPUMLAAttention.forward_impl does not support caller-"
+                "provided output, output_scale, or output_block_scale.")
```
`generate_store_output` and `MockLoadStoreSpec` still use `block_hashes` naming, but the returned `PrepareStoreOutput` now uses the renamed `keys_to_store`/`evicted_keys` fields. Consider renaming the function parameter/local variables (and related mock spec fields, if appropriate) to `keys` to match the updated API and avoid confusion when reading or extending these tests.
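A hedged sketch of the rename the comment asks for; the dataclass below is a local stand-in for the real `PrepareStoreOutput`, and the helper signature is assumed from the tests under review rather than taken from upstream:

```python
from dataclasses import dataclass, field


@dataclass
class PrepareStoreOutput:
    # Stand-in carrying the renamed fields mentioned in the review comment.
    keys_to_store: list[str]
    evicted_keys: list[str] = field(default_factory=list)


def generate_store_output(keys: list[str]) -> PrepareStoreOutput:
    # Parameter renamed from `block_hashes` to `keys` so the call site
    # reads consistently with the renamed PrepareStoreOutput fields.
    return PrepareStoreOutput(keys_to_store=list(keys))
```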