From 43a2ae7958d356c57fe81e59d5df89b3e3843cf7 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Wed, 22 Oct 2025 20:53:16 +0000 Subject: [PATCH 1/2] metainfohash and docstring --- flashinfer/artifacts.py | 23 ++++++++++++++--------- flashinfer/jit/attention/modules.py | 14 ++++++++++---- flashinfer/jit/cubin_loader.py | 17 +++++++---------- flashinfer/jit/fused_moe.py | 12 +++++++++--- flashinfer/jit/gemm/core.py | 20 ++++++++++++++++---- 5 files changed, 56 insertions(+), 30 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 89b458e8d0..6ce8cabab3 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -79,7 +79,13 @@ def get_available_cubin_files( return tuple() +@dataclass(frozen=True) class ArtifactPath: + """ + This class is used to store the paths of the cubin files in artifactory. + The paths are generated in cubin publishing script logs (accessible by codeowners). + When compiling new cubins for backend directories, update the corresponding path. + """ TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23" @@ -93,19 +99,18 @@ class ArtifactPath: @dataclass(frozen=True) class MetaInfoHash: + """ + Encode sha256 hash of kernel_map.json for DEEPGEMM + """ DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" - TRTLLM_GEN_FMHA: str = ( - "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" - ) - TRTLLM_GEN_BMM: str = ( - "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152" - ) - TRTLLM_GEN_GEMM: str = ( - "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" - ) class CheckSumHash: + """ + This class is used to store the checksums of the cubin files in artifactory. + The sha256 hashes are generated in cubin publishing script logs (accessible by codeowners). + When updating the ArtifactPath for backend directories, update the corresponding hash. + """ TRTLLM_GEN_FMHA: str = ( "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" ) diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 14bf3c2e42..0ccc36e1c4 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -28,7 +28,7 @@ sm90a_nvcc_flags, current_compilation_context, ) -from ...jit.cubin_loader import get_cubin +from ...jit.cubin_loader import get_cubin, get_meta_hash, from ..utils import ( dtype_map, filename_safe_dtype_map, @@ -1568,15 +1568,21 @@ def gen_fmha_cutlass_sm100a_module( def gen_trtllm_gen_fmha_module(): - from ...artifacts import ArtifactPath, MetaInfoHash + from ...artifacts import ArtifactPath, CheckSumHash include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" + # Check if checksums.txt exists in the cubin directory + checksum_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt" + checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_FMHA) + assert checksum, f"Failed to get checksums.txt from {checksum_path}" + + meta_hash = get_meta_hash(checksum) # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_FMHA, + meta_hash, ) # make sure "flashinferMetaInfo.h" is downloaded or cached @@ -1592,7 +1598,7 @@ def gen_trtllm_gen_fmha_module(): extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], extra_cuda_cflags=[ f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"', - f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"', + f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{meta_hash}\\"', ], ) diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index a5a9bf6253..f0ce8af175 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -136,18 +136,15 @@ def download_file( return False -def get_meta_hash(checksum_path: str) -> str: +def get_meta_hash(checksums_bytes: bytes) -> str: """ - Load the file from local cache (checksums.txt) - and get the hash of corresponding flashinferMetaInfo.h file + Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file """ - local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt") - with open(local_path, "r") as f: - for line in f: - sha256, filename = line.strip().split() - if ".h" in filename: - return sha256 - raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}") + for line in checksums_bytes.splitlines(): + sha256, filename = line.strip().split() + if ".h" in filename: + return sha256 + raise ValueError(f"Invalid checksums.txt, no flashinferMetaInfo.h found") def verify_cubin(cubin_path: str, expected_sha256: str) -> bool: diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index e4c7515a4e..6216bb78b0 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -17,10 +17,10 @@ from typing import List from . import env as jit_env -from ..artifacts import ArtifactPath, MetaInfoHash +from ..artifacts import ArtifactPath, CheckSumHash from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags from .cpp_ext import is_cuda_version_at_least -from .cubin_loader import get_cubin +from .cubin_loader import get_cubin, get_meta_hash from .gemm.cutlass.generate_kernels import generate_gemm_operations @@ -180,10 +180,16 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include" header_name = "flashinferMetaInfo" + # Check if checksums.txt exists in the cubin directory + checksum_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt" + checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_BMM) + assert checksum, f"Failed to get checksums.txt from {checksum_path}" + meta_hash = get_meta_hash(checksum) + # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_BMM, + meta_hash, ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index a65d1873b5..6564aefa35 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -20,7 +20,7 @@ import jinja2 import torch -from ...artifacts import ArtifactPath, MetaInfoHash +from ...artifacts import ArtifactPath, CheckSumHash from .. import env as jit_env from ..core import ( JitSpec, @@ -30,7 +30,7 @@ sm100f_nvcc_flags, current_compilation_context, ) -from ..cubin_loader import get_cubin +from ..cubin_loader import get_cubin, get_meta_hash from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different @@ -361,10 +361,16 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include" header_name = "flashinferMetaInfo" + # Check if checksums.txt exists in the cubin directory + checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt" + checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM) + assert checksum, f"Failed to get checksums.txt from {checksum_path}" + meta_hash = get_meta_hash(checksum) + # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_GEMM, + meta_hash, ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" @@ -505,10 +511,16 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec: include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include" header_name = "flashinferMetaInfo" + # Check if checksums.txt exists in the cubin directory + checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt" + checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM) + assert checksum, f"Failed to get checksums.txt from {checksum_path}" + meta_hash = get_meta_hash(checksum) + # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h", - MetaInfoHash.TRTLLM_GEN_GEMM, + meta_hash, ) # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" From b3fcb1ca1a4a6fe81cfecb96ceb310074584f466 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Wed, 22 Oct 2025 21:16:33 +0000 Subject: [PATCH 2/2] precommit --- flashinfer/artifacts.py | 2 ++ flashinfer/jit/attention/modules.py | 2 +- flashinfer/jit/cubin_loader.py | 5 +++-- flashinfer/jit/fused_moe.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index d621f92c76..25f679968f 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -86,6 +86,7 @@ class ArtifactPath: The paths are generated in cubin publishing script logs (accessible by codeowners). When compiling new cubins for backend directories, update the corresponding path. """ + TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23" @@ -117,6 +118,7 @@ class CheckSumHash: The sha256 hashes are generated in cubin publishing script logs (accessible by codeowners). When updating the ArtifactPath for backend directories, update the corresponding hash. """ + TRTLLM_GEN_FMHA: str = ( "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" ) diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 0ccc36e1c4..475acdcd1e 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -28,7 +28,7 @@ sm90a_nvcc_flags, current_compilation_context, ) -from ...jit.cubin_loader import get_cubin, get_meta_hash, +from ...jit.cubin_loader import get_cubin, get_meta_hash from ..utils import ( dtype_map, filename_safe_dtype_map, diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index f0ce8af175..1aae47722a 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -140,11 +140,12 @@ def get_meta_hash(checksums_bytes: bytes) -> str: """ Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file """ - for line in checksums_bytes.splitlines(): + checksums_lines = checksums_bytes.decode("utf-8").splitlines() + for line in checksums_lines: sha256, filename = line.strip().split() if ".h" in filename: return sha256 - raise ValueError(f"Invalid checksums.txt, no flashinferMetaInfo.h found") + raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found") def verify_cubin(cubin_path: str, expected_sha256: str) -> bool: diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 6216bb78b0..f0f781ad05 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -185,7 +185,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_BMM) assert checksum, f"Failed to get checksums.txt from {checksum_path}" meta_hash = get_meta_hash(checksum) - + # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}.h",