-
Notifications
You must be signed in to change notification settings - Fork 1k
Revert "Support lse in trtllm paged attn kernels" #3079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1364,11 +1364,15 @@ def run( | |||||||||||||||||||||||||||||
| if rope_theta is None: | ||||||||||||||||||||||||||||||
| rope_theta = 1e4 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| lse_shape = (q.size(0), q.size(1)) | ||||||||||||||||||||||||||||||
| if lse is not None: | ||||||||||||||||||||||||||||||
| check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") | ||||||||||||||||||||||||||||||
| elif return_lse: | ||||||||||||||||||||||||||||||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) | ||||||||||||||||||||||||||||||
| if return_lse: | ||||||||||||||||||||||||||||||
| if lse is None: | ||||||||||||||||||||||||||||||
| lse = torch.empty( | ||||||||||||||||||||||||||||||
| (q.size(0), q.size(1)), dtype=torch.float32, device=q.device | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| check_shape_dtype_device( | ||||||||||||||||||||||||||||||
| lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if out is None: | ||||||||||||||||||||||||||||||
| out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype | ||||||||||||||||||||||||||||||
|
|
@@ -1959,13 +1963,21 @@ def run( | |||||||||||||||||||||||||||||
| out, q_nope.shape, q_nope.dtype, q_nope.device, "out" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| lse_shape = (q_nope.size(0), q_nope.size(1)) | ||||||||||||||||||||||||||||||
| if lse is not None: | ||||||||||||||||||||||||||||||
| check_shape_dtype_device( | ||||||||||||||||||||||||||||||
| lse, lse_shape, torch.float32, q_nope.device, "lse" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| elif return_lse: | ||||||||||||||||||||||||||||||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=device) | ||||||||||||||||||||||||||||||
| if return_lse: | ||||||||||||||||||||||||||||||
| if lse is None: | ||||||||||||||||||||||||||||||
| lse = torch.empty( | ||||||||||||||||||||||||||||||
| (q_nope.size(0), q_nope.size(1)), | ||||||||||||||||||||||||||||||
| dtype=torch.float32, | ||||||||||||||||||||||||||||||
| device=device, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| check_shape_dtype_device( | ||||||||||||||||||||||||||||||
| lse, | ||||||||||||||||||||||||||||||
| (q_nope.size(0), q_nope.size(1)), | ||||||||||||||||||||||||||||||
| q_nope.dtype, | ||||||||||||||||||||||||||||||
| q_nope.device, | ||||||||||||||||||||||||||||||
| "lse", | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
Comment on lines
+1974
to
+1980
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Comment on lines
+1966
to
+1980
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate caller-supplied MLA This branch allocates ♻️ Suggested fix check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
- q_nope.dtype,
+ torch.float32,
q_nope.device,
"lse",
)🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
| self._cached_module.run( | ||||||||||||||||||||||||||||||
| self._float_workspace_buffer, | ||||||||||||||||||||||||||||||
| self._int_workspace_buffer, | ||||||||||||||||||||||||||||||
|
|
@@ -2024,7 +2036,6 @@ def _paged_run( | |||||||||||||||||||||||||||||
| value_block_scales: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx: bool = True, | ||||||||||||||||||||||||||||||
| lse: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||
| if out is None: | ||||||||||||||||||||||||||||||
| out = torch.empty_like(query) | ||||||||||||||||||||||||||||||
|
|
@@ -2070,7 +2081,6 @@ def _paged_run( | |||||||||||||||||||||||||||||
| value_block_scales, | ||||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx, | ||||||||||||||||||||||||||||||
| lse, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -2137,6 +2147,7 @@ def paged_run( | |||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx: bool = True, | ||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||
| assert maybe_lse is None | ||||||||||||||||||||||||||||||
| assert paged_kv_cache is not None | ||||||||||||||||||||||||||||||
| assert num_qo_heads is not None | ||||||||||||||||||||||||||||||
| assert num_kv_heads is not None | ||||||||||||||||||||||||||||||
|
|
@@ -2165,7 +2176,6 @@ def paged_run( | |||||||||||||||||||||||||||||
| value_block_scales=value_block_scales, | ||||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, | ||||||||||||||||||||||||||||||
| lse=maybe_lse, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @register_fake_op(f"flashinfer::{uri}_paged_run") | ||||||||||||||||||||||||||||||
|
|
@@ -2249,11 +2259,7 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||||||||||||||||||||||||||||||
| kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx: bool = True, | ||||||||||||||||||||||||||||||
| lse: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||
| return_lse: bool = False, | ||||||||||||||||||||||||||||||
| ) -> Union[ | ||||||||||||||||||||||||||||||
| torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] | ||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||
| ) -> Union[torch.Tensor, FP4Tensor]: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||
|
|
@@ -2373,19 +2379,10 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||
| True (default) uses vLLM/FlashInfer layout with a 2D page table. | ||||||||||||||||||||||||||||||
| False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| lse: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||
| The log-sum-exp of attention logits, if not provided, will be allocated internally. | ||||||||||||||||||||||||||||||
| Only supported by trtllm-gen backend. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return_lse: bool = False | ||||||||||||||||||||||||||||||
| Whether to return the logsumexp of attention scores, defaults to ``False``. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||
| out : Union[torch.Tensor, FP4Tensor] | ||||||||||||||||||||||||||||||
| output torch.Tensor or FP4Tensor. | ||||||||||||||||||||||||||||||
| lse: Optional[torch.Tensor] | ||||||||||||||||||||||||||||||
| The log-sum-exp of attention logits, if not provided, will be allocated internally. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -2433,11 +2430,6 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||
| backend = ( | ||||||||||||||||||||||||||||||
| "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| wants_lse = return_lse or lse is not None | ||||||||||||||||||||||||||||||
| if wants_lse and backend != "trtllm-gen": | ||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||
| "lse and return_lse are only supported by the trtllm-gen backend" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if backend == "xqa": | ||||||||||||||||||||||||||||||
| # xqa backend doesn't support nvfp4 output | ||||||||||||||||||||||||||||||
|
|
@@ -2586,12 +2578,6 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| lse_shape = (query.size(0), query.size(1)) | ||||||||||||||||||||||||||||||
| if lse is not None: | ||||||||||||||||||||||||||||||
| check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") | ||||||||||||||||||||||||||||||
| elif return_lse: | ||||||||||||||||||||||||||||||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| run_func( | ||||||||||||||||||||||||||||||
| out, | ||||||||||||||||||||||||||||||
| out_scale_factor, | ||||||||||||||||||||||||||||||
|
|
@@ -2620,18 +2606,13 @@ def trtllm_batch_decode_with_kv_cache( | |||||||||||||||||||||||||||||
| v_block_scales, | ||||||||||||||||||||||||||||||
| skip_softmax_threshold_scale_factor, | ||||||||||||||||||||||||||||||
| uses_shared_paged_kv_idx, | ||||||||||||||||||||||||||||||
| lse, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| out = ( | ||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||
| out | ||||||||||||||||||||||||||||||
| if out_dtype != "nvfp4" | ||||||||||||||||||||||||||||||
| else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| if return_lse: | ||||||||||||||||||||||||||||||
| return out, lse | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| raise KeyError(f"Backend {backend} not supported") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: flashinfer-ai/flashinfer
Length of output: 16188
🏁 Script executed:
Repository: flashinfer-ai/flashinfer
Length of output: 1599
🏁 Script executed:
Repository: flashinfer-ai/flashinfer
Length of output: 4692
🏁 Script executed:
Repository: flashinfer-ai/flashinfer
Length of output: 975
Block
lse/return_lseon the TRT-LLM decode wrapper instead of asserting internally.The public wrapper accepts
return_lse=Trueand explicitlsetensors without checking the backend, but passes them to the custom op'spaged_run()which assertsmaybe_lse is None. This causes anAssertionErrorinstead of a stable user-facing error, and fails silently underpython -O.♻️ Suggested fix
🤖 Prompt for AI Agents