Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions flashinfer/api_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,6 @@ def flashinfer_api(func: Callable = None) -> Callable:
This decorator integrates with Python's standard logging infrastructure while
maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).

NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress.

Environment Variables
---------------------
FLASHINFER_LOGLEVEL : int (default: 0)
Expand Down
4 changes: 4 additions & 0 deletions flashinfer/comm/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

import torch

from flashinfer.api_logging import flashinfer_api

from .trtllm_ar import trtllm_allreduce_fusion
from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata
Expand Down Expand Up @@ -270,6 +272,7 @@ def _workspace_creation_heuristic(
# ============================================================================


@flashinfer_api
def create_allreduce_fusion_workspace(
backend: Literal["trtllm", "mnnvl", "auto"] = "auto",
world_size: int = None,
Expand Down Expand Up @@ -440,6 +443,7 @@ def create_allreduce_fusion_workspace(
# ============================================================================


@flashinfer_api
def allreduce_fusion(
input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/cute_dsl/blockscaled_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from cutlass._mlir.dialects import llvm
from flashinfer.utils import get_compute_capability
from flashinfer.api_logging import flashinfer_api
from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
from typing import Callable, List
Expand Down Expand Up @@ -2942,6 +2943,7 @@ def tensor_api(
return tensor_api


@flashinfer_api
def grouped_gemm_nt_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/fused_routing_dsv3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flashinfer.api_logging import flashinfer_api
from flashinfer.jit import gen_dsv3_fused_routing_module
import functools
from types import SimpleNamespace
Expand Down Expand Up @@ -116,6 +117,7 @@ def NoAuxTc(


@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
@flashinfer_api
def fused_topk_deepseek(
scores: torch.Tensor,
bias: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/gemm/routergemm_dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def mm_M1_16_K7168_N256(
)


@flashinfer_api
@backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
@flashinfer_api
def mm_M1_16_K7168_N256(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
Expand Down
1 change: 1 addition & 0 deletions flashinfer/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def can_implement_filtered_topk() -> bool:
return get_topk_module().can_implement_filtered_topk()


@flashinfer_api
def top_k(
input: torch.Tensor,
k: int,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/trtllm_low_latency_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import functools

from flashinfer.api_logging import flashinfer_api
from flashinfer.fused_moe.core import (
convert_to_block_layout,
get_w2_permute_indices_with_cache,
Expand Down Expand Up @@ -193,6 +194,7 @@ def trtllm_low_latency_gemm(
return out


@flashinfer_api
def prepare_low_latency_gemm_weights(
w: torch.Tensor, permutation_indices_cache: Dict[torch.Size, torch.Tensor]
) -> torch.Tensor:
Expand Down