diff --git a/README.md b/README.md index cd5c7e1e58..6c16c2bbab 100644 --- a/README.md +++ b/README.md @@ -181,7 +181,7 @@ export FLASHINFER_LOGLEVEL=3 export FLASHINFER_LOGDEST=stdout ``` -For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md). +For detailed information about logging levels, configuration, and advanced features, see [Logging](https://docs.flashinfer.ai/logging.html) in our documentation. ## Custom Attention Variants diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 67e763fefa..35abb2fdba 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -20,6 +20,7 @@ import torch +from .api_logging import flashinfer_api from .jit import gen_act_and_mul_module from .utils import ( device_support_pdl, @@ -66,6 +67,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: ) +@flashinfer_api def silu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -110,6 +112,7 @@ def silu_and_mul( return out +@flashinfer_api def gelu_tanh_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -150,6 +153,7 @@ def gelu_tanh_and_mul( return out +@flashinfer_api def gelu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -190,6 +194,7 @@ def gelu_and_mul( return out +@flashinfer_api def silu_and_mul_scaled_nvfp4_experts_quantize( a, mask, diff --git a/flashinfer/attention.py b/flashinfer/attention.py index b1e288b903..c4bc4f27dc 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -20,6 +20,7 @@ import torch +from .api_logging import flashinfer_api from .jit import gen_batch_attention_module from .utils import ( MaskMode, @@ -40,6 +41,7 @@ def get_holistic_attention_module(*args): class BatchAttention: + @flashinfer_api def __init__( self, kv_layout: str = "NHD", @@ -65,6 +67,7 @@ def __init__( pin_memory=True, ) + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -132,6 +135,7 @@ def plan( causal, ) + @flashinfer_api def run( self, q: torch.Tensor, diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 267f0d2990..1de363bb37 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .decode import BatchDecodeWithPagedKVCacheWrapper from .jit.cascade import gen_cascade_module from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache @@ -30,6 +31,7 @@ def get_cascade_module(): return gen_cascade_module().build_and_load() +@flashinfer_api @register_custom_op("flashinfer::merge_state", mutates_args=()) def merge_state( v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor @@ -96,6 +98,7 @@ def _fake_merge_state( return v, s +@flashinfer_api @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s")) def merge_state_in_place( v: torch.Tensor, @@ -156,6 +159,7 @@ def _fake_merge_state_in_place( pass +@flashinfer_api @register_custom_op("flashinfer::merge_states", mutates_args=()) def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Merge multiple attention states (v, s). @@ -287,6 +291,7 @@ class MultiLevelCascadeAttentionWrapper: BatchPrefillWithPagedKVCacheWrapper """ + @flashinfer_api def __init__( self, num_levels, @@ -386,6 +391,7 @@ def reset_workspace_buffer( ): wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer) + @flashinfer_api def plan( self, qo_indptr_arr: List[torch.Tensor], @@ -506,6 +512,7 @@ def plan( begin_forward = plan + @flashinfer_api def run( self, q: torch.Tensor, @@ -629,6 +636,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" ) -> None: @@ -656,6 +664,7 @@ def reset_workspace_buffer( float_workspace_buffer, int_workspace_buffer ) + @flashinfer_api def begin_forward( self, unique_kv_indptr: torch.Tensor, @@ -717,6 +726,7 @@ def begin_forward( data_type=data_type, ) + @flashinfer_api def forward( self, q: torch.Tensor, @@ -780,6 +790,7 @@ def forward( merge_state_in_place(V_shared, S_shared, V_unique, S_unique) return V_shared + @flashinfer_api def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect""" pass @@ -876,6 +887,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: layers). This wrapper class manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" ) -> None: @@ -914,6 +926,7 @@ def reset_workspace_buffer( float_workspace_buffer, int_workspace_buffer ) + @flashinfer_api def begin_forward( self, qo_indptr: torch.Tensor, @@ -969,6 +982,7 @@ def begin_forward( page_size, ) + @flashinfer_api def forward( self, q: torch.Tensor, @@ -1060,6 +1074,7 @@ def forward( merge_state_in_place(V_shared, S_shared, V_unique, S_unique) return V_shared + @flashinfer_api def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect""" pass diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 3f9f03ebb7..0765f933df 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1501,6 +1501,7 @@ class BatchDecodeMlaWithPagedKVCacheWrapper: a more efficient and general MLA implementation that supports decode and incremental prefill. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1615,6 +1616,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -1740,6 +1742,7 @@ def plan( self._rope_scale = rope_scale self._rope_theta = rope_theta + @flashinfer_api def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 29127f06ac..0d08123ca7 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -21,6 +21,7 @@ import torch +from .api_logging import flashinfer_api from .jit import JitSpec from .jit import env as jit_env from .jit import ( @@ -624,6 +625,7 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100( ) +@flashinfer_api def fp4_quantize( input: torch.Tensor, global_scale: Optional[torch.Tensor] = None, @@ -689,6 +691,7 @@ def fp4_quantize( return x_q, sf +@flashinfer_api def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: """Swizzle block scale tensor for FP4 format. @@ -721,6 +724,7 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: nvfp4_block_scale_interleave = block_scale_interleave +@flashinfer_api def e2m1_and_ufp8sf_scale_to_float( e2m1_tensor: torch.Tensor, ufp8_scale_tensor: torch.Tensor, @@ -763,6 +767,7 @@ def e2m1_and_ufp8sf_scale_to_float( ) +@flashinfer_api def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: """ PyTorch equivalent of trtllm-gen `shuffleMatrixA` @@ -772,6 +777,7 @@ def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch. return input_tensor[row_indices.to(input_tensor.device)] +@flashinfer_api def shuffle_matrix_sf_a( input_tensor: torch.Tensor, epilogue_tile_m: int, @@ -806,6 +812,7 @@ class SfLayout(Enum): layout_linear = 2 +@flashinfer_api def nvfp4_quantize( a, a_global_sf, @@ -866,6 +873,7 @@ def nvfp4_quantize( return a_fp4, a_sf +@flashinfer_api def mxfp4_quantize(a): """ Quantize input tensor to MXFP4 format. @@ -883,6 +891,7 @@ def mxfp4_quantize(a): return a_fp4, a_sf +@flashinfer_api def mxfp4_dequantize(a_fp4, a_sf): """ Dequantize input tensor from MXFP4 format. @@ -904,6 +913,7 @@ def mxfp4_dequantize(a_fp4, a_sf): ) +@flashinfer_api def mxfp4_dequantize_host( weight: torch.Tensor, scale: torch.Tensor, @@ -932,6 +942,7 @@ def mxfp4_dequantize_host( ) +@flashinfer_api def nvfp4_batched_quantize( a, a_global_sf, @@ -961,6 +972,7 @@ def nvfp4_batched_quantize( return a_fp4, a_sf +@flashinfer_api def scaled_fp4_grouped_quantize( a, mask, diff --git a/flashinfer/fp8_quantization.py b/flashinfer/fp8_quantization.py index 07db59b681..1d2cdeea76 100644 --- a/flashinfer/fp8_quantization.py +++ b/flashinfer/fp8_quantization.py @@ -4,6 +4,7 @@ import torch +from .api_logging import flashinfer_api from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module from .utils import ( device_support_pdl, @@ -142,6 +143,7 @@ def _fake_mxfp8_dequantize_host_sm100( ) +@flashinfer_api def mxfp8_quantize( input: torch.Tensor, is_sf_swizzled_layout: bool = True, @@ -178,6 +180,7 @@ def mxfp8_quantize( return x_q, sf +@flashinfer_api def mxfp8_dequantize_host( input: torch.Tensor, scale_tensor: torch.Tensor, diff --git a/flashinfer/gemm/routergemm_dsv3.py b/flashinfer/gemm/routergemm_dsv3.py index 05415ec61f..c82ccf3fcd 100644 --- a/flashinfer/gemm/routergemm_dsv3.py +++ b/flashinfer/gemm/routergemm_dsv3.py @@ -1,3 +1,4 @@ +from ..api_logging import flashinfer_api from flashinfer.jit import gen_dsv3_router_gemm_module import functools from types import SimpleNamespace @@ -85,6 +86,7 @@ def mm_M1_16_K7168_N256( ) +@flashinfer_api @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks) def mm_M1_16_K7168_N256( mat_a: torch.Tensor, diff --git a/flashinfer/norm.py b/flashinfer/norm.py index 7318974215..5a697186a6 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit.norm import gen_norm_module from .utils import device_support_pdl, register_custom_op, register_fake_op @@ -28,6 +29,7 @@ def get_norm_module(): return gen_norm_module().build_and_load() +@flashinfer_api def rmsnorm( input: torch.Tensor, weight: torch.Tensor, @@ -90,6 +92,7 @@ def _rmsnorm_fake( pass +@flashinfer_api @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual")) def fused_add_rmsnorm( input: torch.Tensor, @@ -136,6 +139,7 @@ def _fused_add_rmsnorm_fake( pass +@flashinfer_api def gemma_rmsnorm( input: torch.Tensor, weight: torch.Tensor, @@ -198,6 +202,7 @@ def _gemma_rmsnorm_fake( pass +@flashinfer_api @register_custom_op( "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual") ) @@ -246,6 +251,7 @@ def _gemma_fused_add_rmsnorm_fake( pass +@flashinfer_api @register_custom_op("flashinfer::layernorm", mutates_args=()) def layernorm( input: torch.Tensor, diff --git a/flashinfer/page.py b/flashinfer/page.py index 069303e501..ba1f46a05b 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit.page import gen_page_module from .utils import ( TensorLayout, @@ -154,6 +155,7 @@ def _fake_append_paged_kv_cache_kernel( pass +@flashinfer_api def get_batch_indices_positions( append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -235,6 +237,7 @@ def get_seq_lens( ) +@flashinfer_api def append_paged_mla_kv_cache( append_ckv: torch.Tensor, append_kpe: torch.Tensor, @@ -284,6 +287,7 @@ def append_paged_mla_kv_cache( ) +@flashinfer_api def append_paged_kv_cache( append_key: torch.Tensor, append_value: torch.Tensor, diff --git a/flashinfer/pod.py b/flashinfer/pod.py index d0a66f7ae9..fe2e36c1ef 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -21,6 +21,7 @@ import torch +from .api_logging import flashinfer_api from .jit import gen_pod_module, gen_batch_pod_module from .page import get_seq_lens from .prefill import get_batch_prefill_module @@ -119,6 +120,7 @@ class PODWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -259,6 +261,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -432,6 +435,7 @@ def plan( begin_forward = plan + @flashinfer_api def run( self, # Main params (prefill and decode) @@ -724,6 +728,7 @@ class BatchPODWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -791,6 +796,7 @@ def __init__( def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph + @flashinfer_api def plan( self, qo_indptr_p: torch.Tensor, @@ -1009,6 +1015,7 @@ def plan( begin_forward = plan + @flashinfer_api def run( self, # Main params (prefill and decode) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 1fa935388a..5b8140ec48 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3554,6 +3554,7 @@ def trtllm_batch_context_with_kv_cache( ) +@flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, key: torch.Tensor, diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index 810b1f2ae1..4e279ab5f0 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit.quantization import gen_quantization_module from .utils import register_custom_op, register_fake_op @@ -42,6 +43,7 @@ def _fake_packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device) +@flashinfer_api def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: r"""Pack the elements of a binary-valued array into bits in a uint8 array. @@ -76,6 +78,7 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: return _packbits(x, bitorder) +@flashinfer_api def segment_packbits( x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big" ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/flashinfer/rope.py b/flashinfer/rope.py index dea6995bcf..1d069e3189 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit.rope import gen_rope_module from .utils import register_custom_op, register_fake_op @@ -413,6 +414,7 @@ def _fake_apply_llama31_rope_pos_ids( pass +@flashinfer_api def apply_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -500,6 +502,7 @@ def apply_rope_inplace( ) +@flashinfer_api def apply_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, @@ -558,6 +561,7 @@ def apply_rope_pos_ids_inplace( ) +@flashinfer_api def apply_llama31_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -666,6 +670,7 @@ def apply_llama31_rope_inplace( ) +@flashinfer_api def apply_llama31_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, @@ -744,6 +749,7 @@ def apply_llama31_rope_pos_ids_inplace( ) +@flashinfer_api def apply_rope( q: torch.Tensor, k: torch.Tensor, @@ -854,6 +860,7 @@ def apply_rope( return q_rope, k_rope +@flashinfer_api def apply_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, @@ -922,6 +929,7 @@ def apply_rope_pos_ids( return q_rope, k_rope +@flashinfer_api def apply_llama31_rope( q: torch.Tensor, k: torch.Tensor, @@ -1044,6 +1052,7 @@ def apply_llama31_rope( return q_rope, k_rope +@flashinfer_api def apply_llama31_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, @@ -1131,6 +1140,7 @@ def apply_llama31_rope_pos_ids( return q_rope, k_rope +@flashinfer_api def apply_rope_with_cos_sin_cache( positions: torch.Tensor, query: torch.Tensor, @@ -1194,6 +1204,7 @@ def apply_rope_with_cos_sin_cache( return query_out, key_out +@flashinfer_api def apply_rope_with_cos_sin_cache_inplace( positions: torch.Tensor, query: torch.Tensor, @@ -1246,6 +1257,7 @@ def apply_rope_with_cos_sin_cache_inplace( ) +@flashinfer_api def mla_rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, @@ -1282,6 +1294,7 @@ def mla_rope_quantize_fp8( ) +@flashinfer_api def rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, @@ -1421,6 +1434,7 @@ def rope_quantize_fp8( return q_rope_out, k_rope_out, q_nope_out, k_nope_out +@flashinfer_api def rope_quantize_fp8_append_paged_kv_cache( q_rope: torch.Tensor, k_rope: torch.Tensor, diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 9f80e3c926..ff437efd33 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -19,6 +19,7 @@ from typing import Any, Optional, Tuple, Union import torch +from .api_logging import flashinfer_api from .jit.sampling import gen_sampling_module from .utils import ( _get_cache_buf, @@ -529,6 +530,7 @@ def _check_tensor_param(param: Any, tensor: torch.Tensor) -> None: ) +@flashinfer_api def softmax( logits: torch.Tensor, temperature: Optional[Union[torch.Tensor, float]] = None, @@ -586,6 +588,7 @@ def softmax( ) +@flashinfer_api def sampling_from_logits( logits: torch.Tensor, indices: Optional[torch.Tensor] = None, @@ -651,6 +654,7 @@ def sampling_from_logits( ) +@flashinfer_api def sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor] = None, @@ -722,6 +726,7 @@ def sampling_from_probs( ) +@flashinfer_api def top_p_sampling_from_probs( probs: torch.Tensor, top_p: Union[torch.Tensor, float], @@ -818,6 +823,7 @@ def top_p_sampling_from_probs( ) +@flashinfer_api def top_k_sampling_from_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], @@ -914,6 +920,7 @@ def top_k_sampling_from_probs( ) +@flashinfer_api def min_p_sampling_from_probs( probs: torch.Tensor, min_p: Union[torch.Tensor, float], @@ -1006,6 +1013,7 @@ def min_p_sampling_from_probs( ) +@flashinfer_api def top_k_top_p_sampling_from_logits( logits: torch.Tensor, top_k: Union[torch.Tensor, int], @@ -1139,6 +1147,7 @@ def top_k_top_p_sampling_from_logits( raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") +@flashinfer_api def top_k_top_p_sampling_from_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], @@ -1265,6 +1274,7 @@ def top_k_top_p_sampling_from_probs( raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") +@flashinfer_api def top_p_renorm_probs( probs: torch.Tensor, top_p: Union[torch.Tensor, float], @@ -1331,6 +1341,7 @@ def top_p_renorm_probs( top_p_renorm_prob = top_p_renorm_probs +@flashinfer_api def top_k_renorm_probs( probs: torch.Tensor, top_k: Union[torch.Tensor, int], @@ -1396,6 +1407,7 @@ def top_k_renorm_probs( top_k_renorm_prob = top_k_renorm_probs +@flashinfer_api def top_k_mask_logits( logits: torch.Tensor, top_k: Union[torch.Tensor, int] ) -> torch.Tensor: @@ -1453,6 +1465,7 @@ def top_k_mask_logits( ) +@flashinfer_api def chain_speculative_sampling( draft_probs, draft_token_ids, diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 37a3d444b7..05622f7410 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .decode import get_batch_decode_module from .page import block_sparse_indices_to_vector_sparse_offsets from .prefill import _compute_page_mask_indptr, get_batch_prefill_module @@ -107,6 +108,7 @@ class BlockSparseAttentionWrapper: True """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -203,6 +205,7 @@ def reset_workspace_buffer( if vector_sparse_indptr_buffer is not None: self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -511,6 +514,7 @@ def forward( self._rope_theta = rope_theta return self.run(q, k, v, scale_q, scale_k, scale_v) + @flashinfer_api def run( self, q: torch.Tensor, @@ -735,6 +739,7 @@ class VariableBlockSparseAttentionWrapper: >>> o = wrapper.run(q, k, v) """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -822,6 +827,7 @@ def reset_workspace_buffer( if vector_sparse_indptr_buffer is not None: self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer + @flashinfer_api def plan( self, block_mask_map: torch.Tensor, @@ -1098,6 +1104,7 @@ def forward( self._rope_theta = rope_theta return self.run(q, k, v) + @flashinfer_api def run( self, q: torch.Tensor, diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 88f96425d5..78104aa9a1 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -19,6 +19,7 @@ from typing import Optional, Union import torch +from .api_logging import flashinfer_api from .jit.xqa import gen_xqa_module, gen_xqa_module_mla from .jit.utils import filename_safe_dtype_map from .utils import ( @@ -143,6 +144,7 @@ def _fake_xqa( ) +@flashinfer_api def xqa( q: torch.Tensor, k_cache: torch.Tensor, @@ -414,6 +416,7 @@ def _fake_xqa_mla( ) +@flashinfer_api def xqa_mla( q: torch.Tensor, k_cache: torch.Tensor,