[ROCm][DSv4] Functional fixes for DeepSeek V4 on MI300X (gfx942) #42893
Changes from all commits
54084ea
6914698
3b4a8eb
b65601e
f677c49
54536ca
4da0919
f13e445
9ae8703
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| DeepseekV4FlashMLASparseBackend, | ||
| DeepseekV4SparseMLAAttentionImpl, | ||
| ) | ||
| from vllm.platforms import current_platform | ||
| from vllm.triton_utils import tl, triton | ||
| from vllm.v1.attention.backend import ( | ||
| CommonAttentionMetadata, | ||
|
|
@@ -789,6 +790,11 @@ def _forward_prefill( | |
| kv = workspace_manager.get_simultaneous( | ||
| ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), | ||
| )[0] | ||
| # TODO: workspace is torch.empty() and only the compressed-K prefix + | ||
| # SWA window are written per chunk row; the indexer's topK can land in | ||
| # the unwritten holes for short sequences. Proper fix is to mask invalid | ||
| # rows in the indexer (score = -inf) or in rocm_sparse_attn_prefill. | ||
| kv.zero_() | ||
|
Member
This is rather weird; there is no need to do this on MI355X. I would like to avoid it, as it incurs overhead. Please look into the sparse indexer logic for gfx942. I believe fixing the logic there can avoid calling
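The TODO in this hunk names masking invalid rows (score = -inf) as the proper fix. A minimal pure-Python sketch of that idea follows. It is not vLLM's indexer code: `masked_topk` and `valid_len` are hypothetical names, and a real implementation would live in the Triton or PyTorch indexer kernel rather than in a Python loop.

```python
import math


def masked_topk(scores, valid_len, k):
    """Pick the k best-scoring positions among the first `valid_len` entries.

    Positions at or beyond `valid_len` stand in for workspace rows that were
    never written. They are scored as -inf so they cannot outrank a valid
    row. If fewer than k valid rows exist, the remaining slots are filled
    with -1 as an "invalid index" sentinel (an assumption of this sketch).
    """
    masked = [s if i < valid_len else -math.inf for i, s in enumerate(scores)]
    # Sort indices by masked score, highest first (sorted() is stable).
    order = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)
    top = order[:k]
    return [i if i < valid_len else -1 for i in top]


# Rows 3 and 4 hold stale, large scores from uninitialized memory.
scores = [0.1, 0.9, 0.3, 7.5, 8.0]
picked = masked_topk(scores, valid_len=3, k=2)
padded = masked_topk(scores, valid_len=3, k=4)
```

With the mask in place, the stale scores in rows 3 and 4 are never selected, so the workspace would not need to be zeroed for correctness in this simplified model.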
||
| for chunk_idx in range(num_chunks): | ||
| chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE | ||
| chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills) | ||
|
|
@@ -797,6 +803,7 @@ def _forward_prefill( | |
| assert attn_metadata is not None | ||
| assert compressed_k_cache is not None | ||
| block_table = attn_metadata.block_table[num_decodes:] | ||
| # compressed_k_cache is OCP on every platform (Triton encoder). | ||
| dequantize_and_gather_k_cache( | ||
| kv[:chunk_size], | ||
| compressed_k_cache, | ||
|
|
@@ -805,6 +812,7 @@ def _forward_prefill( | |
| block_table=block_table[chunk_start:chunk_end], | ||
| block_size=attn_metadata.block_size // layer.compress_ratio, | ||
| offset=0, | ||
| use_fnuz=False, | ||
| ) | ||
|
|
||
| swa_block_table = swa_metadata.block_table[num_decodes:] | ||
|
|
@@ -816,6 +824,7 @@ def _forward_prefill( | |
| block_table=swa_block_table[chunk_start:chunk_end], | ||
| block_size=swa_metadata.block_size, | ||
| offset=N, | ||
| use_fnuz=current_platform.is_fp8_fnuz(), | ||
| ) | ||
|
|
||
| query_start = ( | ||
|
|
||
This test is failing
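For context on the `use_fnuz` arguments in the diff above: the compressed K cache is dequantized with `use_fnuz=False` (the comment says it is OCP on every platform), while the SWA cache uses `current_platform.is_fp8_fnuz()`. The sketch below assumes the flag selects between the OCP E4M3FN and E4M3FNUZ 8-bit encodings, and shows why the choice matters: the same byte decodes to different values. `decode_e4m3` is a hypothetical helper written for illustration, not part of vLLM.

```python
import math


def decode_e4m3(byte, fnuz):
    """Decode one 8-bit E4M3 value (1 sign, 4 exponent, 3 mantissa bits)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if fnuz:
        # FNUZ: exponent bias 8, no infinities, no negative zero;
        # the single NaN encoding is 0x80.
        if byte == 0x80:
            return math.nan
        bias = 8
    else:
        # OCP E4M3FN: exponent bias 7, no infinities;
        # NaN is exponent 1111 with mantissa 111.
        if exp == 0xF and man == 0x7:
            return math.nan
        bias = 7
    if exp == 0:
        # Subnormal: no implicit leading one.
        return sign * (man / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - bias)


same_byte_ocp = decode_e4m3(0x38, fnuz=False)
same_byte_fnuz = decode_e4m3(0x38, fnuz=True)
```

Decoding OCP-encoded bytes as FNUZ halves every finite value, and the NaN encodings differ, which is why each cache needs the flag that matches how it was written.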