Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ export FLASHINFER_LOGLEVEL=3
export FLASHINFER_LOGDEST=stdout
```

For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md).
For detailed information about logging levels, configuration, and advanced features, see [Logging](https://docs.flashinfer.ai/logging.html) in our documentation.

## Custom Attention Variants

Expand Down
5 changes: 5 additions & 0 deletions flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch

from .api_logging import flashinfer_api
from .jit import gen_act_and_mul_module
from .utils import (
device_support_pdl,
Expand Down Expand Up @@ -66,6 +67,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
)


@flashinfer_api
def silu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -110,6 +112,7 @@ def silu_and_mul(
return out


@flashinfer_api
def gelu_tanh_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -150,6 +153,7 @@ def gelu_tanh_and_mul(
return out


@flashinfer_api
def gelu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -190,6 +194,7 @@ def gelu_and_mul(
return out


@flashinfer_api
def silu_and_mul_scaled_nvfp4_experts_quantize(
a,
mask,
Expand Down
4 changes: 4 additions & 0 deletions flashinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch

from .api_logging import flashinfer_api
from .jit import gen_batch_attention_module
from .utils import (
MaskMode,
Expand All @@ -40,6 +41,7 @@ def get_holistic_attention_module(*args):


class BatchAttention:
@flashinfer_api
def __init__(
self,
kv_layout: str = "NHD",
Expand All @@ -65,6 +67,7 @@ def __init__(
pin_memory=True,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Thanks for adding the decorator to the `plan` and `run` methods. For consistency with other wrapper classes in this PR (e.g., `MultiLevelCascadeAttentionWrapper`), could you also add the `@flashinfer_api` decorator to the `__init__` method of the `BatchAttention` class? It appears to be a public API that should be logged.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

@flashinfer_api
def plan(
self,
qo_indptr: torch.Tensor,
Expand Down Expand Up @@ -132,6 +135,7 @@ def plan(
causal,
)

@flashinfer_api
def run(
self,
q: torch.Tensor,
Expand Down
15 changes: 15 additions & 0 deletions flashinfer/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch

from .api_logging import flashinfer_api
from .decode import BatchDecodeWithPagedKVCacheWrapper
from .jit.cascade import gen_cascade_module
from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache
Expand All @@ -30,6 +31,7 @@ def get_cascade_module():
return gen_cascade_module().build_and_load()


@flashinfer_api
@register_custom_op("flashinfer::merge_state", mutates_args=())
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
Expand Down Expand Up @@ -96,6 +98,7 @@ def _fake_merge_state(
return v, s


@flashinfer_api
@register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
def merge_state_in_place(
v: torch.Tensor,
Expand Down Expand Up @@ -156,6 +159,7 @@ def _fake_merge_state_in_place(
pass


@flashinfer_api
@register_custom_op("flashinfer::merge_states", mutates_args=())
def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Merge multiple attention states (v, s).
Expand Down Expand Up @@ -287,6 +291,7 @@ class MultiLevelCascadeAttentionWrapper:
BatchPrefillWithPagedKVCacheWrapper
"""

@flashinfer_api
def __init__(
self,
num_levels,
Expand Down Expand Up @@ -386,6 +391,7 @@ def reset_workspace_buffer(
):
wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)

@flashinfer_api
def plan(
self,
qo_indptr_arr: List[torch.Tensor],
Expand Down Expand Up @@ -506,6 +512,7 @@ def plan(

begin_forward = plan

@flashinfer_api
def run(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -629,6 +636,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
manages the lifecycle of these data structures.
"""

@flashinfer_api
def __init__(
self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
) -> None:
Expand Down Expand Up @@ -656,6 +664,7 @@ def reset_workspace_buffer(
float_workspace_buffer, int_workspace_buffer
)

@flashinfer_api
def begin_forward(
self,
unique_kv_indptr: torch.Tensor,
Expand Down Expand Up @@ -717,6 +726,7 @@ def begin_forward(
data_type=data_type,
)

@flashinfer_api
def forward(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -780,6 +790,7 @@ def forward(
merge_state_in_place(V_shared, S_shared, V_unique, S_unique)
return V_shared

@flashinfer_api
def end_forward(self) -> None:
r"""Warning: this function is deprecated and has no effect"""
pass
Expand Down Expand Up @@ -876,6 +887,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
layers). This wrapper class manages the lifecycle of these data structures.
"""

@flashinfer_api
def __init__(
self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
) -> None:
Expand Down Expand Up @@ -914,6 +926,7 @@ def reset_workspace_buffer(
float_workspace_buffer, int_workspace_buffer
)

@flashinfer_api
def begin_forward(
self,
qo_indptr: torch.Tensor,
Expand Down Expand Up @@ -969,6 +982,7 @@ def begin_forward(
page_size,
)

@flashinfer_api
def forward(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -1060,6 +1074,7 @@ def forward(
merge_state_in_place(V_shared, S_shared, V_unique, S_unique)
return V_shared

@flashinfer_api
def end_forward(self) -> None:
r"""Warning: this function is deprecated and has no effect"""
pass
3 changes: 3 additions & 0 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,7 @@ class BatchDecodeMlaWithPagedKVCacheWrapper:
a more efficient and general MLA implementation that supports decode and incremental prefill.
"""

@flashinfer_api
def __init__(
self,
float_workspace_buffer: torch.Tensor,
Expand Down Expand Up @@ -1615,6 +1616,7 @@ def reset_workspace_buffer(
pin_memory=True,
)

@flashinfer_api
def plan(
self,
indptr: torch.Tensor,
Expand Down Expand Up @@ -1740,6 +1742,7 @@ def plan(
self._rope_scale = rope_scale
self._rope_theta = rope_theta

@flashinfer_api
def run(
self,
q_nope: torch.Tensor,
Expand Down
12 changes: 12 additions & 0 deletions flashinfer/fp4_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import torch

from .api_logging import flashinfer_api
from .jit import JitSpec
from .jit import env as jit_env
from .jit import (
Expand Down Expand Up @@ -624,6 +625,7 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100(
)


@flashinfer_api
def fp4_quantize(
input: torch.Tensor,
global_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -689,6 +691,7 @@ def fp4_quantize(
return x_q, sf


@flashinfer_api
def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
"""Swizzle block scale tensor for FP4 format.

Expand Down Expand Up @@ -721,6 +724,7 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
nvfp4_block_scale_interleave = block_scale_interleave


@flashinfer_api
def e2m1_and_ufp8sf_scale_to_float(
e2m1_tensor: torch.Tensor,
ufp8_scale_tensor: torch.Tensor,
Expand Down Expand Up @@ -763,6 +767,7 @@ def e2m1_and_ufp8sf_scale_to_float(
)


@flashinfer_api
def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
"""
PyTorch equivalent of trtllm-gen `shuffleMatrixA`
Expand All @@ -772,6 +777,7 @@ def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.
return input_tensor[row_indices.to(input_tensor.device)]


@flashinfer_api
def shuffle_matrix_sf_a(
input_tensor: torch.Tensor,
epilogue_tile_m: int,
Expand Down Expand Up @@ -806,6 +812,7 @@ class SfLayout(Enum):
layout_linear = 2


@flashinfer_api
def nvfp4_quantize(
a,
a_global_sf,
Expand Down Expand Up @@ -866,6 +873,7 @@ def nvfp4_quantize(
return a_fp4, a_sf


@flashinfer_api
def mxfp4_quantize(a):
"""
Quantize input tensor to MXFP4 format.
Expand All @@ -883,6 +891,7 @@ def mxfp4_quantize(a):
return a_fp4, a_sf


@flashinfer_api
def mxfp4_dequantize(a_fp4, a_sf):
"""
Dequantize input tensor from MXFP4 format.
Expand All @@ -904,6 +913,7 @@ def mxfp4_dequantize(a_fp4, a_sf):
)


@flashinfer_api
def mxfp4_dequantize_host(
weight: torch.Tensor,
scale: torch.Tensor,
Expand Down Expand Up @@ -932,6 +942,7 @@ def mxfp4_dequantize_host(
)


@flashinfer_api
def nvfp4_batched_quantize(
a,
a_global_sf,
Expand Down Expand Up @@ -961,6 +972,7 @@ def nvfp4_batched_quantize(
return a_fp4, a_sf


@flashinfer_api
def scaled_fp4_grouped_quantize(
a,
mask,
Expand Down
3 changes: 3 additions & 0 deletions flashinfer/fp8_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from .api_logging import flashinfer_api
from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
from .utils import (
device_support_pdl,
Expand Down Expand Up @@ -142,6 +143,7 @@ def _fake_mxfp8_dequantize_host_sm100(
)


@flashinfer_api
def mxfp8_quantize(
input: torch.Tensor,
is_sf_swizzled_layout: bool = True,
Expand Down Expand Up @@ -178,6 +180,7 @@ def mxfp8_quantize(
return x_q, sf


@flashinfer_api
def mxfp8_dequantize_host(
input: torch.Tensor,
scale_tensor: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/gemm/routergemm_dsv3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..api_logging import flashinfer_api
from flashinfer.jit import gen_dsv3_router_gemm_module
import functools
from types import SimpleNamespace
Expand Down Expand Up @@ -85,6 +86,7 @@ def mm_M1_16_K7168_N256(
)


@flashinfer_api
@backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
def mm_M1_16_K7168_N256(
mat_a: torch.Tensor,
Expand Down
Loading