diff --git a/docs/api/attention.rst b/docs/api/attention.rst index eff9160787..4aec5f4164 100644 --- a/docs/api/attention.rst +++ b/docs/api/attention.rst @@ -25,6 +25,7 @@ Batch Decoding cudnn_batch_decode_with_kv_cache trtllm_batch_decode_with_kv_cache + fast_decode_plan .. autoclass:: BatchDecodeWithPagedKVCacheWrapper :members: @@ -108,3 +109,21 @@ PageAttention for MLA :members: .. automethod:: __init__ + + +flashinfer.pod +============== + +POD (Prefix-Only Decode) attention kernels for efficient prefix caching. + +.. currentmodule:: flashinfer.pod + +.. autoclass:: PODWithPagedKVCacheWrapper + :members: + + .. automethod:: __init__ + +.. autoclass:: BatchPODWithPagedKVCacheWrapper + :members: + + .. automethod:: __init__ diff --git a/docs/api/comm.rst b/docs/api/comm.rst index 32e4b4ee4d..1144ba84e9 100644 --- a/docs/api/comm.rst +++ b/docs/api/comm.rst @@ -46,7 +46,7 @@ Types and Enums AllReduceFusionPattern AllReduceStrategyConfig AllReduceStrategyType - FP4QuantizationSFLayout + QuantizationSFLayout Core Operations ~~~~~~~~~~~~~~~ @@ -94,6 +94,26 @@ vLLM AllReduce vllm_get_graph_buffer_ipc_meta vllm_meta_size +Unified AllReduce Fusion API +----------------------------- + +.. autosummary:: + :toctree: ../generated + + AllReduceFusionWorkspace + TRTLLMAllReduceFusionWorkspace + allreduce_fusion + create_allreduce_fusion_workspace + +.. currentmodule:: flashinfer.comm.trtllm_mnnvl_ar + +.. autosummary:: + :toctree: ../generated + + MNNVLAllReduceFusionWorkspace + +.. currentmodule:: flashinfer.comm + MNNVL (Multi-Node NVLink) ------------------------- diff --git a/docs/api/gemm.rst b/docs/api/gemm.rst index 8c9fbeeea6..00141ae608 100644 --- a/docs/api/gemm.rst +++ b/docs/api/gemm.rst @@ -22,6 +22,7 @@ FP8 GEMM :toctree: ../generated bmm_fp8 + mm_fp8 gemm_fp8_nt_groupwise group_gemm_fp8_nt_groupwise group_deepgemm_fp8_nt_groupwise @@ -43,3 +44,32 @@ Grouped GEMM (Ampere/Hopper) :exclude-members: forward .. automethod:: __init__ + +Blackwell GEMM +-------------- + +.. autosummary:: + :toctree: ../generated + + tgv_gemm_sm100 + +TensorRT-LLM Low Latency GEMM +------------------------------ + +.. currentmodule:: flashinfer.trtllm_low_latency_gemm + +.. autosummary:: + :toctree: ../generated + + prepare_low_latency_gemm_weights + +CuTe-DSL GEMM +------------- + +.. currentmodule:: flashinfer.gemm + +.. autosummary:: + :toctree: ../generated + + grouped_gemm_nt_masked + Sm100BlockScaledPersistentDenseGemmKernel diff --git a/docs/api/norm.rst b/docs/api/norm.rst index 98c0d4b5fa..7c692c93be 100644 --- a/docs/api/norm.rst +++ b/docs/api/norm.rst @@ -11,7 +11,11 @@ Kernels for normalization layers. :toctree: ../generated rmsnorm + rmsnorm_quant fused_add_rmsnorm + fused_add_rmsnorm_quant gemma_rmsnorm gemma_fused_add_rmsnorm layernorm + rmsnorm_fp4quant + add_rmsnorm_fp4quant diff --git a/docs/api/quantization.rst b/docs/api/quantization.rst index c7461aeefd..64d0294550 100644 --- a/docs/api/quantization.rst +++ b/docs/api/quantization.rst @@ -12,3 +12,16 @@ Quantization related kernels. packbits segment_packbits + +flashinfer.fp8_quantization +=========================== + +FP8 Quantization kernels. + +.. currentmodule:: flashinfer.fp8_quantization + +.. autosummary:: + :toctree: ../generated + + mxfp8_quantize + mxfp8_dequantize_host diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index c2abfd8e2e..35d9d90324 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -89,6 +89,15 @@ from .gemm import mm_fp4 as mm_fp4 from .gemm import mm_fp8 as mm_fp8 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 + +# CuTe-DSL GEMM kernels (conditionally available) +try: + from .gemm import grouped_gemm_nt_masked as grouped_gemm_nt_masked + from .gemm import ( + Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel, + ) +except ImportError: + pass # CuTe-DSL not available from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper from .norm import fused_add_rmsnorm as fused_add_rmsnorm from .norm import layernorm as layernorm diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 35abb2fdba..893b8fe046 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -222,3 +222,11 @@ def silu_and_mul_scaled_nvfp4_experts_quantize( a_global_sf, ) return a_fp4, a_sf + + +__all__ = [ + "gelu_and_mul", + "gelu_tanh_and_mul", + "silu_and_mul", + "silu_and_mul_scaled_nvfp4_experts_quantize", +] diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 34096af940..46a71d5fb4 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -881,5 +881,10 @@ def main(): ) +__all__ = [ + "register_default_modules", +] + + if __name__ == "__main__": main() diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 734d6bae28..6e95eb8e0d 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -563,3 +563,8 @@ def wrapper(*args, **kwargs): if func is None: return decorator return decorator(func) + + +__all__ = [ + "flashinfer_api", +] diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 717524bc9e..8a5a813591 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -248,3 +248,18 @@ def clear_cubin(): shutil.rmtree(FLASHINFER_CUBIN_DIR) else: logger.info(f"Cubin directory does not exist: {FLASHINFER_CUBIN_DIR}") + + +__all__ = [ + # Classes + "ArtifactPath", + "CheckSumHash", + # Functions + "temp_env_var", + "get_available_cubin_files", + "get_checksums", + "get_subdir_file_list", + "download_artifacts", + "get_artifacts_status", + "clear_cubin", +] diff --git a/flashinfer/attention.py b/flashinfer/attention.py index c4bc4f27dc..087afe32be 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -277,3 +277,9 @@ def __init__( jit_args=jit_args, jit_kwargs=jit_kwargs, ) + + +__all__ = [ + "BatchAttention", + "BatchAttentionWithAttentionSinkWrapper", +] diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index a81c8f2546..b0d066a0e4 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -789,3 +789,8 @@ def clear_cache(self) -> None: def reset_statistics(self) -> None: """Reset all statistics counters.""" self.stats = AutoTunerStatistics() + + +__all__ = [ + "autotune", +] diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 1de363bb37..7c1251b318 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -1078,3 +1078,13 @@ def forward( def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect""" pass + + +__all__ = [ + "BatchDecodeWithSharedPrefixPagedKVCacheWrapper", + "BatchPrefillWithSharedPrefixPagedKVCacheWrapper", + "MultiLevelCascadeAttentionWrapper", + "merge_state", + "merge_state_in_place", + "merge_states", +] diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 5f186002dc..ae46375f4d 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -66,3 +66,58 @@ ) # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo + +__all__ = [ + # CUDA IPC Utilities + "CudaRTLibrary", + "create_shared_buffer", + "free_shared_buffer", + # DLPack Utilities + "pack_strided_memory", + # Mapping Utilities + "Mapping", + # TensorRT-LLM AllReduce - Types and Enums + "AllReduceFusionOp", + "AllReduceFusionPattern", + "AllReduceStrategyConfig", + "AllReduceStrategyType", + "QuantizationSFLayout", + # TensorRT-LLM AllReduce - Core Operations + "trtllm_allreduce_fusion", + "trtllm_custom_all_reduce", + "trtllm_moe_allreduce_fusion", + "trtllm_moe_finalize_allreduce_fusion", + # TensorRT-LLM AllReduce - Workspace Management + "trtllm_create_ipc_workspace_for_all_reduce", + "trtllm_create_ipc_workspace_for_all_reduce_fusion", + "trtllm_destroy_ipc_workspace_for_all_reduce", + "trtllm_destroy_ipc_workspace_for_all_reduce_fusion", + # TensorRT-LLM AllReduce - Initialization and Utilities + "trtllm_lamport_initialize", + "trtllm_lamport_initialize_all", + "compute_fp4_swizzled_layout_sf_size", + "gen_trtllm_comm_module", + # vLLM AllReduce + "vllm_all_reduce", + "vllm_dispose", + "gen_vllm_comm_module", + "vllm_get_graph_buffer_ipc_meta", + "vllm_init_custom_ar", + "vllm_meta_size", + "vllm_register_buffer", + "vllm_register_graph_buffers", + # Unified AllReduce Fusion API + "AllReduceFusionWorkspace", + "TRTLLMAllReduceFusionWorkspace", + "MNNVLAllReduceFusionWorkspace", + "allreduce_fusion", + "create_allreduce_fusion_workspace", + # MNNVL A2A (Throughput Backend) + "MoeAlltoAll", + "moe_a2a_combine", + "moe_a2a_dispatch", + "moe_a2a_initialize", + "moe_a2a_get_workspace_size_per_rank", + "moe_a2a_sanitize_expert_ids", + "moe_a2a_wrap_payload_tensor_in_workspace", +] diff --git a/flashinfer/compilation_context.py b/flashinfer/compilation_context.py index dc0d20a584..478c97312b 100644 --- a/flashinfer/compilation_context.py +++ b/flashinfer/compilation_context.py @@ -66,3 +66,8 @@ def get_nvcc_flags_list( f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}" for major, minor in supported_cuda_archs ] + self.COMMON_NVCC_FLAGS + + +__all__ = [ + "CompilationContext", +] diff --git a/flashinfer/concat_ops.py b/flashinfer/concat_ops.py index 8957092a22..e1b825278b 100644 --- a/flashinfer/concat_ops.py +++ b/flashinfer/concat_ops.py @@ -80,3 +80,9 @@ def concat_mla_k( - ``rope_dim = 64`` """ get_concat_mla_module().concat_mla_k(k, k_nope, k_rope) + + +__all__ = [ + "get_concat_mla_module", + "concat_mla_k", +] diff --git a/flashinfer/cuda_utils.py b/flashinfer/cuda_utils.py index 4d97e8675d..7ef555ab6d 100644 --- a/flashinfer/cuda_utils.py +++ b/flashinfer/cuda_utils.py @@ -59,3 +59,8 @@ def checkCudaErrors(result): return result[1] else: return result[1:] + + +__all__ = [ + "checkCudaErrors", +] diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 63c011e148..2967f149fb 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2677,3 +2677,13 @@ def fast_decode_plan( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +__all__ = [ + "BatchDecodeMlaWithPagedKVCacheWrapper", + "BatchDecodeWithPagedKVCacheWrapper", + "CUDAGraphBatchDecodeWithPagedKVCacheWrapper", + "cudnn_batch_decode_with_kv_cache", + "fast_decode_plan", + "single_decode_with_kv_cache", +] diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 18e90d68d0..a8332a1d67 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -1609,3 +1609,14 @@ def __getitem__(self, key): KERNEL_MAP = KernelMap() + + +__all__ = [ + # Classes + "KernelMap", + # Functions + "load", + "load_all", + "m_grouped_fp8_gemm_nt_contiguous", + "m_grouped_fp8_gemm_nt_masked", +] diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 7a2e0bde6f..3335be260a 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -999,3 +999,20 @@ def scaled_fp4_grouped_quantize( mask, ) return a_fp4, a_sf + + +__all__ = [ + "SfLayout", + "block_scale_interleave", + "nvfp4_block_scale_interleave", + "e2m1_and_ufp8sf_scale_to_float", + "fp4_quantize", + "mxfp4_dequantize_host", + "mxfp4_dequantize", + "mxfp4_quantize", + "nvfp4_quantize", + "nvfp4_batched_quantize", + "shuffle_matrix_a", + "shuffle_matrix_sf_a", + "scaled_fp4_grouped_quantize", +] diff --git a/flashinfer/fp8_quantization.py b/flashinfer/fp8_quantization.py index 1d2cdeea76..78f0e4ca99 100644 --- a/flashinfer/fp8_quantization.py +++ b/flashinfer/fp8_quantization.py @@ -206,3 +206,9 @@ def mxfp8_dequantize_host( scale_tensor, is_sf_swizzled_layout, ) + + +__all__ = [ + "mxfp8_dequantize_host", + "mxfp8_quantize", +] diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index ed66b0bd9c..5fe2035ceb 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -19,6 +19,17 @@ mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256, ) +# CuTe-DSL GEMM kernels +try: + from ..cute_dsl import grouped_gemm_nt_masked as grouped_gemm_nt_masked + from ..cute_dsl import ( + Sm100BlockScaledPersistentDenseGemmKernel as Sm100BlockScaledPersistentDenseGemmKernel, + ) + + _CUTE_DSL_AVAILABLE = True +except ImportError: + _CUTE_DSL_AVAILABLE = False + __all__ = [ "SegmentGEMMWrapper", "bmm_fp8", @@ -34,3 +45,9 @@ "fp8_blockscale_gemm_sm90", "mm_M1_16_K7168_N256", ] + +if _CUTE_DSL_AVAILABLE: + __all__ += [ + "grouped_gemm_nt_masked", + "Sm100BlockScaledPersistentDenseGemmKernel", + ] diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index 09962fd467..eaac54efbb 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -293,3 +293,15 @@ def split_device_green_ctx_by_sm_count( f"Please reduce the requested SM counts or use fewer partitions." ) from e raise + + +__all__ = [ + "get_sm_count_constraint", + "get_cudevice", + "get_device_resource", + "split_resource", + "split_resource_by_sm_count", + "create_green_ctx_streams", + "split_device_green_ctx", + "split_device_green_ctx_by_sm_count", +] diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 3d76524e62..a2901a1209 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -92,3 +92,67 @@ ) if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"): ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL) + +__all__ = [ + # Submodules + "cubin_loader", + "env", + # Activation + "gen_act_and_mul_module", + "get_act_and_mul_cu_str", + # Attention + "gen_cudnn_fmha_module", + "gen_batch_attention_module", + "gen_batch_decode_mla_module", + "gen_batch_decode_module", + "gen_batch_mla_module", + "gen_batch_prefill_module", + "gen_customize_batch_decode_module", + "gen_customize_batch_prefill_module", + "gen_customize_single_decode_module", + "gen_customize_single_prefill_module", + "gen_fmha_cutlass_sm100a_module", + "gen_batch_pod_module", + "gen_pod_module", + "gen_single_decode_module", + "gen_single_prefill_module", + "get_batch_attention_uri", + "get_batch_decode_mla_uri", + "get_batch_decode_uri", + "get_batch_mla_uri", + "get_batch_prefill_uri", + "get_pod_uri", + "get_single_decode_uri", + "get_single_prefill_uri", + "gen_trtllm_gen_fmha_module", + "get_trtllm_fmha_v2_module", + # Core + "JitSpec", + "JitSpecStatus", + "JitSpecRegistry", + "jit_spec_registry", + "build_jit_specs", + "clear_cache_dir", + "gen_jit_spec", + "MissingJITCacheError", + "sm90a_nvcc_flags", + "sm100a_nvcc_flags", + "sm100f_nvcc_flags", + "sm103a_nvcc_flags", + "sm110a_nvcc_flags", + "sm120a_nvcc_flags", + "sm121a_nvcc_flags", + "current_compilation_context", + # Cubin Loader + "setup_cubin_loader", + # Comm + "gen_comm_alltoall_module", + "gen_trtllm_mnnvl_comm_module", + "gen_trtllm_comm_module", + "gen_vllm_comm_module", + "gen_nvshmem_module", + "gen_moe_alltoall_module", + # DSv3 Optimizations + "gen_dsv3_router_gemm_module", + "gen_dsv3_fused_routing_module", +] diff --git a/flashinfer/logits_processor/__init__.py b/flashinfer/logits_processor/__init__.py index 80e611189a..21e32e1503 100644 --- a/flashinfer/logits_processor/__init__.py +++ b/flashinfer/logits_processor/__init__.py @@ -32,3 +32,31 @@ from .processors import TopP as TopP from .types import TaggedTensor as TaggedTensor from .types import TensorType as TensorType + +__all__ = [ + # Compiler + "CompileError", + "Compiler", + "compile_pipeline", + # Fusion Rules + "FusionRule", + # Legalization + "LegalizationError", + "legalize_processors", + # Operators + "Op", + "ParameterizedOp", + # Pipeline + "LogitsPipe", + # Processors + "LogitsProcessor", + "MinP", + "Sample", + "Softmax", + "Temperature", + "TopK", + "TopP", + # Types + "TaggedTensor", + "TensorType", +] diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 83415521c3..0e8dcfcd8f 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -797,3 +797,13 @@ def xqa_batch_decode_with_kv_cache_mla( ) return out + + +__all__ = [ + "BatchMLAPagedAttentionWrapper", + "trtllm_batch_decode_with_kv_cache_mla", + "xqa_batch_decode_with_kv_cache_mla", + "get_trtllm_gen_fmha_module", + "get_mla_module", + "get_batch_mla_module", +] diff --git a/flashinfer/norm.py b/flashinfer/norm.py index de27b12d7a..ad3b4819ad 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm.py @@ -411,3 +411,15 @@ def _layernorm_fake( # CuTe-DSL not available rmsnorm_fp4quant = None # type: ignore[misc,assignment] add_rmsnorm_fp4quant = None # type: ignore[misc,assignment] + +__all__ = [ + "add_rmsnorm_fp4quant", + "fused_add_rmsnorm", + "fused_add_rmsnorm_quant", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "layernorm", + "rmsnorm", + "rmsnorm_fp4quant", + "rmsnorm_quant", +] diff --git a/flashinfer/page.py b/flashinfer/page.py index 5a000c3a15..cdfa77c532 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -379,3 +379,11 @@ def append_paged_kv_cache( kv_last_page_len, TensorLayout[kv_layout].value, ) + + +__all__ = [ + "append_paged_kv_cache", + "append_paged_mla_kv_cache", + "get_batch_indices_positions", + "get_seq_lens", +] diff --git a/flashinfer/pod.py b/flashinfer/pod.py index fe2e36c1ef..3b0f43bd4b 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -1199,3 +1199,9 @@ def run( def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass + + +__all__ = [ + "BatchPODWithPagedKVCacheWrapper", + "PODWithPagedKVCacheWrapper", +] diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 463ae1009f..eaa746448b 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3754,3 +3754,11 @@ def fmha_v2_prefill_deepseek( return out, lse else: return out + + +__all__ = [ + "BatchPrefillWithPagedKVCacheWrapper", + "BatchPrefillWithRaggedKVCacheWrapper", + "single_prefill_with_kv_cache", + "single_prefill_with_kv_cache_return_lse", +] diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index 4e279ab5f0..2c75e729b9 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -137,3 +137,9 @@ def segment_packbits( y = torch.empty(output_nnzs, dtype=torch.uint8, device=device) get_quantization_module().segment_packbits(x, indptr, indptr_new, bitorder, y) return y, indptr_new + + +__all__ = [ + "packbits", + "segment_packbits", +] diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 1d069e3189..14c2a49aad 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1669,3 +1669,17 @@ def rope_quantize_fp8_append_paged_kv_cache( ) return q_rope_out, q_nope_out + + +__all__ = [ + "apply_llama31_rope", + "apply_llama31_rope_inplace", + "apply_llama31_rope_pos_ids", + "apply_llama31_rope_pos_ids_inplace", + "apply_rope", + "apply_rope_inplace", + "apply_rope_pos_ids", + "apply_rope_pos_ids_inplace", + "apply_rope_with_cos_sin_cache", + "apply_rope_with_cos_sin_cache_inplace", +] diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 8514da3e15..22f0e9974b 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -1585,3 +1585,19 @@ def chain_speculative_sampling( offset, ) return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num + + +__all__ = [ + "chain_speculative_sampling", + "min_p_sampling_from_probs", + "sampling_from_logits", + "sampling_from_probs", + "softmax", + "top_k_mask_logits", + "top_k_renorm_probs", + "top_k_sampling_from_probs", + "top_k_top_p_sampling_from_logits", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_probs", + "top_p_sampling_from_probs", +] diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 652194ab17..41c256ad5d 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -1166,3 +1166,9 @@ def run( ).contiguous() return (out, lse) if return_lse else out + + +__all__ = [ + "BlockSparseAttentionWrapper", + "VariableBlockSparseAttentionWrapper", +] diff --git a/flashinfer/testing/__init__.py b/flashinfer/testing/__init__.py index 690f550a29..07cedafdb2 100644 --- a/flashinfer/testing/__init__.py +++ b/flashinfer/testing/__init__.py @@ -28,3 +28,18 @@ set_seed, sleep_after_kernel_run, ) + +__all__ = [ + "attention_flops", + "attention_flops_with_actual_seq_lens", + "attention_tb_per_sec", + "attention_tb_per_sec_with_actual_seq_lens", + "attention_tflops_per_sec", + "attention_tflops_per_sec_with_actual_seq_lens", + "bench_gpu_time", + "bench_gpu_time_with_cupti", + "bench_gpu_time_with_cuda_event", + "bench_gpu_time_with_cudagraph", + "set_seed", + "sleep_after_kernel_run", +] diff --git a/flashinfer/tllm_utils.py b/flashinfer/tllm_utils.py index f365779c0b..5c6aa25e9d 100644 --- a/flashinfer/tllm_utils.py +++ b/flashinfer/tllm_utils.py @@ -10,3 +10,9 @@ def get_trtllm_utils_module(): def delay_kernel(stream_delay_micro_secs): get_trtllm_utils_module().delay_kernel(stream_delay_micro_secs) + + +__all__ = [ + "get_trtllm_utils_module", + "delay_kernel", +] diff --git a/flashinfer/topk.py b/flashinfer/topk.py index 4c1c01cf23..7780322b5e 100644 --- a/flashinfer/topk.py +++ b/flashinfer/topk.py @@ -419,3 +419,10 @@ def top_k_ragged_transform( ) return output_indices + + +__all__ = [ + "top_k", + "top_k_page_table_transform", + "top_k_ragged_transform", +] diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py index 2d69bc1e98..988558ffe9 100644 --- a/flashinfer/trtllm_low_latency_gemm.py +++ b/flashinfer/trtllm_low_latency_gemm.py @@ -222,3 +222,8 @@ def prepare_low_latency_gemm_weights( block_k = 128 block_layout_weights = convert_to_block_layout(shuffled_weights, block_k) return block_layout_weights + + +__all__ = [ + "prepare_low_latency_gemm_weights", +] diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 35861b2507..533bf9a056 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1182,3 +1182,76 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +__all__ = [ + # Enums + "PosEncodingMode", + "MaskMode", + "TensorLayout", + "LogLevel", + # Exceptions + "GPUArchitectureError", + "LibraryError", + "BackendSupportedError", + # Classes + "FP4Tensor", + # Constants + "log2e", + "log_level_map", + # Utility functions + "next_positive_power_of_2", + "calculate_tile_tokens_dim", + "is_float8", + "get_indptr", + "get_alibi_slopes", + "canonicalize_torch_dtype", + "get_compute_capability", + "get_gpu_memory_bandwidth", + "ceil_div", + "round_up", + "get_device_sm_count", + "check_shape_dtype_device", + # Backend and version functions + "determine_gemm_backend", + "is_fa3_backend_supported", + "is_cutlass_backend_supported", + "determine_attention_backend", + "version_at_least", + "has_cuda_cudart", + "has_flashinfer_jit_cache", + "has_flashinfer_cubin", + "get_cuda_python_version", + "determine_mla_backend", + # SM support functions + "is_sm90a_supported", + "is_sm100a_supported", + "is_sm100f_supported", + "is_sm110a_supported", + "is_sm120a_supported", + "is_sm121a_supported", + "device_support_pdl", + # Logging + "set_log_level", + # FP4/Shuffle functions + "get_shuffle_block_size", + "get_shuffle_matrix_a_row_indices", + "get_shuffle_matrix_sf_a_row_indices", + "get_native_fp4_dtype", + # Decorators + "supported_compute_capability", + "backend_requirement", + "register_custom_op", + "register_fake_op", + # Internal helpers (used across modules) + "_expand_5d", + "_expand_4d", + "_check_pos_encoding_mode", + "_check_kv_layout", + "_unpack_paged_kv_cache", + "_get_cache_buf", + "_ceil_pow2", + "_get_range_buf", + "_get_cache_alibi_slopes_buf", + "_check_cached_qkv_data_type", +] diff --git a/flashinfer/version.py b/flashinfer/version.py index 95ad245497..2a1ba11b79 100644 --- a/flashinfer/version.py +++ b/flashinfer/version.py @@ -21,3 +21,9 @@ except ModuleNotFoundError: __version__ = "0.0.0+unknown" __git_version__ = "unknown" + + +__all__ = [ + "__version__", + "__git_version__", +] diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 729c79a718..e7612125cd 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -527,3 +527,9 @@ def xqa_mla( workspace_buffer, enable_pdl, ) + + +__all__ = [ + "xqa", + "xqa_mla", +]