Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/api/attention.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Batch Decoding

cudnn_batch_decode_with_kv_cache
trtllm_batch_decode_with_kv_cache
fast_decode_plan

.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
:members:
Expand Down Expand Up @@ -108,3 +109,21 @@ PageAttention for MLA
:members:

.. automethod:: __init__


flashinfer.pod
==============

POD-Attention kernels, which compute prefill and decode attention together so that the two phases can overlap on the GPU.

.. currentmodule:: flashinfer.pod

.. autoclass:: PODWithPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: BatchPODWithPagedKVCacheWrapper
:members:

.. automethod:: __init__
22 changes: 21 additions & 1 deletion docs/api/comm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Types and Enums
AllReduceFusionPattern
AllReduceStrategyConfig
AllReduceStrategyType
FP4QuantizationSFLayout
QuantizationSFLayout

Core Operations
~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -94,6 +94,26 @@ vLLM AllReduce
vllm_get_graph_buffer_ipc_meta
vllm_meta_size

Unified AllReduce Fusion API
-----------------------------

.. autosummary::
:toctree: ../generated

AllReduceFusionWorkspace
TRTLLMAllReduceFusionWorkspace
allreduce_fusion
create_allreduce_fusion_workspace

.. currentmodule:: flashinfer.comm.trtllm_mnnvl_ar

.. autosummary::
:toctree: ../generated

MNNVLAllReduceFusionWorkspace

.. currentmodule:: flashinfer.comm

MNNVL (Multi-Node NVLink)
-------------------------

Expand Down
30 changes: 30 additions & 0 deletions docs/api/gemm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ FP8 GEMM
:toctree: ../generated

bmm_fp8
mm_fp8
gemm_fp8_nt_groupwise
group_gemm_fp8_nt_groupwise
group_deepgemm_fp8_nt_groupwise
Expand All @@ -43,3 +44,32 @@ Grouped GEMM (Ampere/Hopper)
:exclude-members: forward

.. automethod:: __init__

Blackwell GEMM
--------------

.. autosummary::
:toctree: ../generated

tgv_gemm_sm100

TensorRT-LLM Low Latency GEMM
------------------------------

.. currentmodule:: flashinfer.trtllm_low_latency_gemm

.. autosummary::
:toctree: ../generated

prepare_low_latency_gemm_weights

CuTe-DSL GEMM
-------------

.. currentmodule:: flashinfer.gemm

.. autosummary::
:toctree: ../generated

grouped_gemm_nt_masked
Sm100BlockScaledPersistentDenseGemmKernel
4 changes: 4 additions & 0 deletions docs/api/norm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ Kernels for normalization layers.
:toctree: ../generated

rmsnorm
rmsnorm_quant
fused_add_rmsnorm
fused_add_rmsnorm_quant
gemma_rmsnorm
gemma_fused_add_rmsnorm
layernorm
rmsnorm_fp4quant
add_rmsnorm_fp4quant
13 changes: 13 additions & 0 deletions docs/api/quantization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ Quantization related kernels.

packbits
segment_packbits

flashinfer.fp8_quantization
===========================

FP8 Quantization kernels.

.. currentmodule:: flashinfer.fp8_quantization

.. autosummary::
:toctree: ../generated

mxfp8_quantize
mxfp8_dequantize_host
9 changes: 9 additions & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100

# CuTe-DSL GEMM kernels (conditionally available)
try:
from .gemm import grouped_gemm_nt_masked as grouped_gemm_nt_masked
from .gemm import (
Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel,
)
except ImportError:
pass # CuTe-DSL not available
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
Expand Down
8 changes: 8 additions & 0 deletions flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,11 @@ def silu_and_mul_scaled_nvfp4_experts_quantize(
a_global_sf,
)
return a_fp4, a_sf


# Public API of ``flashinfer.activation``: the names exported by
# ``from flashinfer.activation import *``.
__all__ = [
"gelu_and_mul",
"gelu_tanh_and_mul",
"silu_and_mul",
"silu_and_mul_scaled_nvfp4_experts_quantize",
]
5 changes: 5 additions & 0 deletions flashinfer/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,5 +881,10 @@ def main():
)


# Only ``register_default_modules`` is exported by
# ``from flashinfer.aot import *``; ``main`` is not listed.
__all__ = [
"register_default_modules",
]


# Run ``main()`` only when this file is executed as a script, not on import.
if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions flashinfer/api_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,8 @@ def wrapper(*args, **kwargs):
if func is None:
return decorator
return decorator(func)


# Public API of ``flashinfer.api_logging``: the names exported by
# ``from flashinfer.api_logging import *``.
__all__ = [
"flashinfer_api",
]
15 changes: 15 additions & 0 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,18 @@ def clear_cubin():
shutil.rmtree(FLASHINFER_CUBIN_DIR)
else:
logger.info(f"Cubin directory does not exist: {FLASHINFER_CUBIN_DIR}")


# Public API of ``flashinfer.artifacts``: the names exported by
# ``from flashinfer.artifacts import *``, grouped into classes and functions.
__all__ = [
# Classes
"ArtifactPath",
"CheckSumHash",
# Functions
"temp_env_var",
"get_available_cubin_files",
"get_checksums",
"get_subdir_file_list",
"download_artifacts",
"get_artifacts_status",
"clear_cubin",
]
6 changes: 6 additions & 0 deletions flashinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,9 @@ def __init__(
jit_args=jit_args,
jit_kwargs=jit_kwargs,
)


# Public API of ``flashinfer.attention``: the names exported by
# ``from flashinfer.attention import *``.
__all__ = [
"BatchAttention",
"BatchAttentionWithAttentionSinkWrapper",
]
5 changes: 5 additions & 0 deletions flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,8 @@ def clear_cache(self) -> None:
def reset_statistics(self) -> None:
"""Reset all statistics counters by replacing ``self.stats`` with a new ``AutoTunerStatistics()``."""
self.stats = AutoTunerStatistics()


# Only ``autotune`` is exported by ``from flashinfer.autotuner import *``;
# other names defined in this module must be imported explicitly.
__all__ = [
"autotune",
]
10 changes: 10 additions & 0 deletions flashinfer/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,3 +1078,13 @@ def forward(
def end_forward(self) -> None:
r"""Warning: this function is deprecated and has no effect; it only returns ``None``."""
pass


# Public API of ``flashinfer.cascade``: the names exported by
# ``from flashinfer.cascade import *``.
__all__ = [
"BatchDecodeWithSharedPrefixPagedKVCacheWrapper",
"BatchPrefillWithSharedPrefixPagedKVCacheWrapper",
"MultiLevelCascadeAttentionWrapper",
"merge_state",
"merge_state_in_place",
"merge_states",
]
55 changes: 55 additions & 0 deletions flashinfer/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,58 @@
)

# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo

# Public API of the ``flashinfer.comm`` package: the names exported by
# ``from flashinfer.comm import *``. Entries are grouped by the section
# comments below.
__all__ = [
# CUDA IPC Utilities
"CudaRTLibrary",
"create_shared_buffer",
"free_shared_buffer",
# DLPack Utilities
"pack_strided_memory",
# Mapping Utilities
"Mapping",
# TensorRT-LLM AllReduce - Types and Enums
"AllReduceFusionOp",
"AllReduceFusionPattern",
"AllReduceStrategyConfig",
"AllReduceStrategyType",
"QuantizationSFLayout",
# TensorRT-LLM AllReduce - Core Operations
"trtllm_allreduce_fusion",
"trtllm_custom_all_reduce",
"trtllm_moe_allreduce_fusion",
"trtllm_moe_finalize_allreduce_fusion",
# TensorRT-LLM AllReduce - Workspace Management
"trtllm_create_ipc_workspace_for_all_reduce",
"trtllm_create_ipc_workspace_for_all_reduce_fusion",
"trtllm_destroy_ipc_workspace_for_all_reduce",
"trtllm_destroy_ipc_workspace_for_all_reduce_fusion",
# TensorRT-LLM AllReduce - Initialization and Utilities
"trtllm_lamport_initialize",
"trtllm_lamport_initialize_all",
"compute_fp4_swizzled_layout_sf_size",
"gen_trtllm_comm_module",
# vLLM AllReduce
"vllm_all_reduce",
"vllm_dispose",
"gen_vllm_comm_module",
"vllm_get_graph_buffer_ipc_meta",
"vllm_init_custom_ar",
"vllm_meta_size",
"vllm_register_buffer",
"vllm_register_graph_buffers",
# Unified AllReduce Fusion API
"AllReduceFusionWorkspace",
"TRTLLMAllReduceFusionWorkspace",
"MNNVLAllReduceFusionWorkspace",
"allreduce_fusion",
"create_allreduce_fusion_workspace",
# MNNVL A2A (Throughput Backend)
"MoeAlltoAll",
"moe_a2a_combine",
"moe_a2a_dispatch",
"moe_a2a_initialize",
"moe_a2a_get_workspace_size_per_rank",
"moe_a2a_sanitize_expert_ids",
"moe_a2a_wrap_payload_tensor_in_workspace",
]
5 changes: 5 additions & 0 deletions flashinfer/compilation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@ def get_nvcc_flags_list(
f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
for major, minor in supported_cuda_archs
] + self.COMMON_NVCC_FLAGS


# Public API of ``flashinfer.compilation_context``: the names exported by
# ``from flashinfer.compilation_context import *``.
__all__ = [
"CompilationContext",
]
6 changes: 6 additions & 0 deletions flashinfer/concat_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,9 @@ def concat_mla_k(
- ``rope_dim = 64``
"""
get_concat_mla_module().concat_mla_k(k, k_nope, k_rope)


# Public API of ``flashinfer.concat_ops``: the names exported by
# ``from flashinfer.concat_ops import *``. This includes the module getter
# ``get_concat_mla_module`` as well as the ``concat_mla_k`` operator.
__all__ = [
"get_concat_mla_module",
"concat_mla_k",
]
5 changes: 5 additions & 0 deletions flashinfer/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,8 @@ def checkCudaErrors(result):
return result[1]
else:
return result[1:]


# Public API of ``flashinfer.cuda_utils``: the names exported by
# ``from flashinfer.cuda_utils import *``.
__all__ = [
"checkCudaErrors",
]
10 changes: 10 additions & 0 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2677,3 +2677,13 @@ def fast_decode_plan(
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta


# Public API of ``flashinfer.decode``: the names exported by
# ``from flashinfer.decode import *``.
#
# Every function listed in the "Batch Decoding" autosummary of
# ``docs/api/attention.rst`` should appear here, so that the documented API
# and the star-import surface agree.
__all__ = [
    "BatchDecodeMlaWithPagedKVCacheWrapper",
    "BatchDecodeWithPagedKVCacheWrapper",
    "CUDAGraphBatchDecodeWithPagedKVCacheWrapper",
    "cudnn_batch_decode_with_kv_cache",
    "fast_decode_plan",
    "single_decode_with_kv_cache",
    "trtllm_batch_decode_with_kv_cache",
]
11 changes: 11 additions & 0 deletions flashinfer/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,3 +1609,14 @@ def __getitem__(self, key):


# Module-level ``KernelMap`` instance, created when the module is imported.
KERNEL_MAP = KernelMap()


# Public API of ``flashinfer.deep_gemm``: the names exported by
# ``from flashinfer.deep_gemm import *``.
# NOTE(review): ``KERNEL_MAP`` itself is not listed, so a star-import does not
# bring it in — confirm this is intended.
__all__ = [
# Classes
"KernelMap",
# Functions
"load",
"load_all",
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_gemm_nt_masked",
]
17 changes: 17 additions & 0 deletions flashinfer/fp4_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,3 +999,20 @@ def scaled_fp4_grouped_quantize(
mask,
)
return a_fp4, a_sf


# Public API of ``flashinfer.fp4_quantization``: the names exported by
# ``from flashinfer.fp4_quantization import *``.
__all__ = [
"SfLayout",
"block_scale_interleave",
"nvfp4_block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float",
"fp4_quantize",
"mxfp4_dequantize_host",
"mxfp4_dequantize",
"mxfp4_quantize",
"nvfp4_quantize",
"nvfp4_batched_quantize",
"shuffle_matrix_a",
"shuffle_matrix_sf_a",
"scaled_fp4_grouped_quantize",
]
6 changes: 6 additions & 0 deletions flashinfer/fp8_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,9 @@ def mxfp8_dequantize_host(
scale_tensor,
is_sf_swizzled_layout,
)


# Public API of ``flashinfer.fp8_quantization``: the names exported by
# ``from flashinfer.fp8_quantization import *``. These are the same two
# functions listed in the ``flashinfer.fp8_quantization`` section of
# ``docs/api/quantization.rst``.
__all__ = [
"mxfp8_dequantize_host",
"mxfp8_quantize",
]
Loading