diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 734d6bae28..d3ab73b3ad 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -469,8 +469,6 @@ def flashinfer_api(func: Callable = None) -> Callable: This decorator integrates with Python's standard logging infrastructure while maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0). - NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress. - Environment Variables --------------------- FLASHINFER_LOGLEVEL : int (default: 0) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 237ef37102..a24745a076 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -53,6 +53,8 @@ import torch +from flashinfer.api_logging import flashinfer_api + from .trtllm_ar import trtllm_allreduce_fusion from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata @@ -270,6 +272,7 @@ def _workspace_creation_heuristic( # ============================================================================ +@flashinfer_api def create_allreduce_fusion_workspace( backend: Literal["trtllm", "mnnvl", "auto"] = "auto", world_size: int = None, @@ -440,6 +443,7 @@ def create_allreduce_fusion_workspace( # ============================================================================ +@flashinfer_api def allreduce_fusion( input: torch.Tensor, workspace: AllReduceFusionWorkspace, diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index 78a888dac6..d1843a33aa 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -55,6 +55,7 @@ ) from cutlass._mlir.dialects import llvm from flashinfer.utils import get_compute_capability +from flashinfer.api_logging import flashinfer_api from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr from typing import Callable, List @@ -2942,6 +2943,7 @@ def tensor_api( return tensor_api +@flashinfer_api def grouped_gemm_nt_masked( lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py index 9c6fb79c91..5e26ca30cf 100644 --- a/flashinfer/fused_moe/fused_routing_dsv3.py +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -1,3 +1,4 @@ +from flashinfer.api_logging import flashinfer_api from flashinfer.jit import gen_dsv3_fused_routing_module import functools from types import SimpleNamespace @@ -116,6 +117,7 @@ def NoAuxTc( @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) +@flashinfer_api def fused_topk_deepseek( scores: torch.Tensor, bias: torch.Tensor, diff --git a/flashinfer/gemm/routergemm_dsv3.py b/flashinfer/gemm/routergemm_dsv3.py index c82ccf3fcd..3a3e1c93d4 100644 --- a/flashinfer/gemm/routergemm_dsv3.py +++ b/flashinfer/gemm/routergemm_dsv3.py @@ -86,8 +86,8 @@ def mm_M1_16_K7168_N256( ) -@flashinfer_api @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks) +@flashinfer_api def mm_M1_16_K7168_N256( mat_a: torch.Tensor, mat_b: torch.Tensor, diff --git a/flashinfer/topk.py b/flashinfer/topk.py index 4c1c01cf23..d1cfb754be 100644 --- a/flashinfer/topk.py +++ b/flashinfer/topk.py @@ -152,6 +152,7 @@ def can_implement_filtered_topk() -> bool: return get_topk_module().can_implement_filtered_topk() +@flashinfer_api def top_k( input: torch.Tensor, k: int, diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py index 2d69bc1e98..e7beb51343 100644 --- a/flashinfer/trtllm_low_latency_gemm.py +++ b/flashinfer/trtllm_low_latency_gemm.py @@ -19,6 +19,7 @@ import functools +from flashinfer.api_logging import flashinfer_api from flashinfer.fused_moe.core import ( convert_to_block_layout, get_w2_permute_indices_with_cache, @@ -193,6 +194,7 @@ def trtllm_low_latency_gemm( return out +@flashinfer_api def prepare_low_latency_gemm_weights( w: torch.Tensor, permutation_indices_cache: Dict[torch.Size, torch.Tensor] ) -> torch.Tensor: