Code Review
This pull request adds support for Triton-based mRoPE, including a new end-to-end test, updates to unit tests, and the core implementation in AscendMRotaryEmbedding. While the feature is a good addition, I've found several critical issues that must be addressed. The new e2e test contains bugs in its reference implementation and test logic, which invalidates its results. More importantly, the AscendMRotaryEmbedding implementation has a critical flaw in its caching logic that will lead to incorrect computations by using stale data. I've also noted a minor issue in the unit test mocks. Please see the detailed comments for specifics and suggestions for fixes.
```python
if mrope_section[1] > 0:
    cos_row[t_end:h_end] = token_cos[t_end:h_end, 1]
    sin_row[t_end:h_end] = token_cos[t_end:h_end, 1]
```
There appears to be a copy-paste error here: `sin_row` is being assigned values from `token_cos` instead of `token_sin`. This makes the reference implementation incorrect, and the test will validate the Triton kernel against wrong values.
```diff
- sin_row[t_end:h_end] = token_cos[t_end:h_end, 1]
+ sin_row[t_end:h_end] = token_sin[t_end:h_end, 1]
```
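For context, here is a minimal sketch of how such a reference implementation is assumed to stitch the temporal/height/width streams together (the names `token_cos`, `token_sin`, and `mrope_section` follow the test; the exact shapes are an assumption for illustration):

```python
import torch

def build_rows(token_cos, token_sin, mrope_section):
    # token_cos / token_sin: (rotary_dim // 2, 3) -- one column per
    # temporal/height/width position stream; mrope_section gives how
    # many rotary dims are taken from each stream.
    t_end = mrope_section[0]
    h_end = t_end + mrope_section[1]
    w_end = h_end + mrope_section[2]
    cos_row = torch.empty(w_end)
    sin_row = torch.empty(w_end)
    cos_row[:t_end] = token_cos[:t_end, 0]
    sin_row[:t_end] = token_sin[:t_end, 0]
    if mrope_section[1] > 0:
        cos_row[t_end:h_end] = token_cos[t_end:h_end, 1]
        # reads token_sin here -- the line the review fixes read token_cos
        sin_row[t_end:h_end] = token_sin[t_end:h_end, 1]
    if mrope_section[2] > 0:
        cos_row[h_end:w_end] = token_cos[h_end:w_end, 2]
        sin_row[h_end:w_end] = token_sin[h_end:w_end, 2]
    return cos_row, sin_row
```

With `sin = -cos` as a quick sanity input, the corrected version yields `sin_row == -cos_row` elementwise, which the buggy `token_cos` read would violate wherever `cos != 0`.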
```python
q_gold, k_gold = pytorch_forward_native(q_gold,
                                        k_gold,
                                        cos,
                                        sin,
                                        mrope_section,
                                        num_tokens,
                                        head_size,
                                        True)
```
The arguments passed to `pytorch_forward_native` are incorrect. The `head_size` parameter is receiving `num_tokens`, and the `rotary_dim` parameter is receiving `head_size`. The actual `rotary_dim` from the test parameters is not being used at all. This makes the test logic flawed and it will not correctly validate the Triton kernel's output.
```diff
  q_gold, k_gold = pytorch_forward_native(q_gold,
                                          k_gold,
                                          cos,
                                          sin,
                                          mrope_section,
-                                         num_tokens,
-                                         head_size,
+                                         head_size,
+                                         rotary_dim,
                                          True)
```
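This class of bug is easy to rule out by calling with keyword arguments. A runnable sketch, using a hypothetical stub that only mirrors the signature assumed from this review and echoes the scalar arguments back so the binding can be checked:

```python
# Hypothetical stub -- same parameter order as assumed in the review,
# returning the two scalars so we can verify which values they bound to.
def pytorch_forward_native(query, key, cos, sin, mrope_section,
                           head_size, rotary_dim, is_neox_style):
    return head_size, rotary_dim

# Keyword arguments make a swapped head_size/rotary_dim impossible:
hs, rd = pytorch_forward_native(None, None, None, None,
                                mrope_section=[16, 24, 24],
                                head_size=128,
                                rotary_dim=64,
                                is_neox_style=True)
```

Positional calls with several same-typed integer parameters in a row are exactly where this kind of swap slips through review.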
```python
self.cos = None
self.sin = None
```
These attributes are used in `forward_triton` to cache `cos` and `sin` values. However, this caching is stateful and incorrect because the values depend on `positions`, which can change between calls. This will lead to using stale cached data. These attributes should be removed, and the caching logic in `forward_triton` should be corrected to be stateless within the forward pass.
```python
if self.cos is None and self.sin is None:
    cos_sin = self.cos_sin_cache[positions]  # type: ignore
    cos, sin = cos_sin.chunk(2, dim=-1)
    self.cos = cos.contiguous()
    self.sin = sin.contiguous()
query_shape = query.shape
key_shape = key.shape
assert self.mrope_section
q, k = triton_mrope(
    query,
    key,
    self.cos,
    self.sin,
    self.mrope_section,
    self.head_size,
    self.rotary_dim,
    self.mrope_interleaved,
)
```
The caching logic for cos and sin is incorrect. These tensors depend on positions, which can change on each call to forward_triton. Caching them as instance attributes (self.cos, self.sin) will cause subsequent calls with different positions to use stale values, leading to incorrect results. The cos and sin tensors should be computed on every call and not stored on self.
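To make the hazard concrete, here is a minimal, self-contained sketch (a hypothetical class, not the real `AscendMRotaryEmbedding`) showing that caching a position-dependent slice on `self` silently returns stale values on every later call:

```python
import torch

class BadRope:
    """Minimal repro of the caching bug: cos/sin depend on positions
    but are only computed on the first call."""

    def __init__(self, max_pos=16, dim=4):
        # toy lookup table standing in for cos_sin_cache
        self.cos_sin_cache = torch.arange(
            max_pos * dim, dtype=torch.float32).view(max_pos, dim)
        self.cos = None
        self.sin = None

    def forward(self, positions):
        if self.cos is None and self.sin is None:  # bug: first call only
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            self.cos = cos.contiguous()
            self.sin = sin.contiguous()
        return self.cos, self.sin

rope = BadRope()
cos1, _ = rope.forward(torch.tensor([0, 1]))
cos2, _ = rope.forward(torch.tensor([2, 3]))  # different positions...
print(torch.equal(cos1, cos2))                # ...same stale cos: True
```

Indexing `cos_sin_cache[positions]` on every call, as in the suggested fix below the comment, removes the state entirely at the cost of one cheap gather per forward.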
```python
cos_sin = self.cos_sin_cache[positions]  # type: ignore
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
assert self.mrope_section
q, k = triton_mrope(
    query,
    key,
    cos.contiguous(),
    sin.contiguous(),
    self.mrope_section,
    self.head_size,
    self.rotary_dim,
    self.mrope_interleaved,
)
```

In the unit test mocks, `@patch('vllm.triton_utils.HAS_TRITON', return_value=True)` does not set the flag to `True`: with `return_value` it replaces the attribute with a `MagicMock` configured so that *calling* it returns `True`. Since `HAS_TRITON` is read as a boolean, the patch should pass the value directly:

```diff
  @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
  @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
- @patch('vllm.triton_utils.HAS_TRITON', return_value=True)
+ @patch('vllm.triton_utils.HAS_TRITON', True)
```
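A runnable sketch of the difference, using a throwaway module object as a stand-in for `vllm.triton_utils` (the real module is not imported here):

```python
import types
from unittest.mock import patch

# Hypothetical stand-in module with a boolean flag, mirroring the
# role of vllm.triton_utils.HAS_TRITON.
mod = types.ModuleType("fake_triton_utils")
mod.HAS_TRITON = False

# Passing the new value directly replaces the attribute with True:
with patch.object(mod, "HAS_TRITON", True):
    assert mod.HAS_TRITON is True

# return_value=True instead configures a MagicMock; the attribute is
# now a mock object -- truthy, but not the boolean True:
with patch.object(mod, "HAS_TRITON", return_value=True):
    assert mod.HAS_TRITON is not True
    assert mod.HAS_TRITON()          # only calling it yields True
```

A `MagicMock` is truthy, so a `if HAS_TRITON:` check may still pass by accident, which is why this bug is easy to miss in a green test run.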
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
```python
    "beta_fast": beta_fast,
    "beta_slow": beta_slow
}
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
```
Signed-off-by: shiyuan680 <917935075@qq.com>
…to eplb_refactor * 'main' of https://github.com/vllm-project/vllm-ascend:
- [CI] Unblock 4-cards test (vllm-project#5831)
- [Refactor] Provide a framework to accommodate operators for different hardware devices (vllm-project#5735)
- [Refactor] Modify the binding logic to allocate CPU cores for each NPU card (vllm-project#5555)
- [BugFix] Support setting tp=1 for the Eagle draft model to take effect (vllm-project#5519)
- support triton of mrope (vllm-project#5664)
- [bugfix] A2 Environment Pooling for Memcache Compatibility (vllm-project#5601)
- [Doc] Update community contributors and versioning naming to follow vLLM (vllm-project#5820)
- [Refactor] Add comments for Metadata classes in attention module (vllm-project#5789)
- [Bugfix] bugfix for the order of dummy run pad and sync (vllm-project#5777)
- [CI] Move nightly-a2 test to hk (vllm-project#5807)
- [CI] Show disk usage for CI shared volume (vllm-project#5821)
- Bump actions/checkout from 4 to 6 (vllm-project#5795)
- Bump actions/github-script from 7 to 8 (vllm-project#5796)
- [bugfix](cp) align max_context_chunk to cp_virtual_block_size (vllm-project#5767)
- [bugfix]limit graph replay sync (vllm-project#5761)
- [CI]Add Kimi k2 nightly test (vllm-project#5682)
- [Doc] add tls check to pd disaggregation readme (vllm-project#5638)
- [CI] adpat v0.13.0 change (vllm-project#5793)
### What this PR does / why we need it?
This PR adds support for a Triton mRoPE path analogous to `cuda_forward`; its performance is on par with the AscendC ops. The Triton op requires CANN 8.5.0.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Tested on qwen3-vl-235b, textvqa accuracy:

| backend | textvqa acc |
| --- | --- |
| native | 81.82 |
| npu triton | 81.58 |
| cuda triton | 81.52 |

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2f4e654

Signed-off-by: shiyuan680 <917935075@qq.com>
Adapted from vllm-ascend PR vllm-project#5664. Adds forward_triton path to AscendMRotaryEmbedding that uses vllm's triton_mrope kernel, with corresponding test. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adapted from vllm-ascend PR vllm-project#5664 and PR vllm-project#6042. Adds Triton-based mRoPE support to AscendMRotaryEmbedding, with fix to only use Triton path when mrope_interleaved is True.