Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def get_available_cubin_files(
return tuple()


@dataclass(frozen=True)
class ArtifactPath:
"""
This class is used to store the paths of the cubin files in artifactory.
The paths are generated in cubin publishing script logs (accessible by codeowners).
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
Expand All @@ -106,6 +113,12 @@ class MetaInfoHash:


class CheckSumHash:
"""
This class is used to store the checksums of the cubin files in artifactory.
The sha256 hashes are generated in cubin publishing script logs (accessible by codeowners).
When updating the ArtifactPath for backend directories, update the corresponding hash.
"""

TRTLLM_GEN_FMHA: str = (
"639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
)
Expand Down
14 changes: 10 additions & 4 deletions flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sm90a_nvcc_flags,
current_compilation_context,
)
from ...jit.cubin_loader import get_cubin
from ...jit.cubin_loader import get_cubin, get_meta_hash
from ..utils import (
dtype_map,
filename_safe_dtype_map,
Expand Down Expand Up @@ -1568,15 +1568,21 @@ def gen_fmha_cutlass_sm100a_module(


def gen_trtllm_gen_fmha_module():
from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath, CheckSumHash

include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include"
header_name = "flashInferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_FMHA)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"

meta_hash = get_meta_hash(checksum)
# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_FMHA,
meta_hash,
)

# make sure "flashinferMetaInfo.h" is downloaded or cached
Expand All @@ -1592,7 +1598,7 @@ def gen_trtllm_gen_fmha_module():
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
extra_cuda_cflags=[
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{meta_hash}\\"',
],
)

Expand Down
18 changes: 8 additions & 10 deletions flashinfer/jit/cubin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,16 @@ def download_file(
return False


def get_meta_hash(checksum_path: str) -> str:
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Load the file from local cache (checksums.txt)
and get the hash of corresponding flashinferMetaInfo.h file
Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
"""
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
with open(local_path, "r") as f:
for line in f:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
checksums_lines = checksums_bytes.decode("utf-8").splitlines()
for line in checksums_lines:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
Comment on lines +139 to +148
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

get_meta_hash: match MetaInfo header precisely; handle decoding/split robustly.

Current logic returns the sha of the first ".h" entry, which can select the wrong file and break the header fetch. It may also raise ValueError on malformed lines, and it assumes UTF-8 without a fallback.

Apply:

-def get_meta_hash(checksums_bytes: bytes) -> str:
-    """
-    Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
-    """
-    checksums_lines = checksums_bytes.decode("utf-8").splitlines()
-    for line in checksums_lines:
-        sha256, filename = line.strip().split()
-        if ".h" in filename:
-            return sha256
-    raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
+def get_meta_hash(checksums_bytes: bytes) -> str:
+    """
+    Parse checksums.txt content (bytes) and return the sha256 for the MetaInfo header.
+    """
+    text = checksums_bytes.decode("utf-8", errors="replace")
+    for line in text.splitlines():
+        parts = line.strip().split(maxsplit=1)
+        if len(parts) != 2:
+            continue
+        sha256, filename = parts  # both str
+        if filename.lower().endswith("metainfo.h"):
+            return sha256
+    raise ValueError("No MetaInfo header entry in checksums.txt")
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Load the file from local cache (checksums.txt)
and get the hash of corresponding flashinferMetaInfo.h file
Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
"""
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
with open(local_path, "r") as f:
for line in f:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
checksums_lines = checksums_bytes.decode("utf-8").splitlines()
for line in checksums_lines:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Parse checksums.txt content (bytes) and return the sha256 for the MetaInfo header.
"""
text = checksums_bytes.decode("utf-8", errors="replace")
for line in text.splitlines():
parts = line.strip().split(maxsplit=1)
if len(parts) != 2:
continue
sha256, filename = parts # both str
if filename.lower().endswith("metainfo.h"):
return sha256
raise ValueError("No MetaInfo header entry in checksums.txt")
🧰 Tools
🪛 Ruff (0.14.1)

148-148: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 139-148, the parser currently
returns the first ".h" entry and can crash on malformed lines or non-UTF8 bytes;
change it to decode checksums_bytes with errors='replace' (or 'ignore'),
normalize/remove BOM, iterate each line and split with maxsplit=1 to avoid
ValueErrors on extra whitespace, strip surrounding quotes from the filename and
use the basename, then only return the checksum when the filename exactly
matches "flashinferMetaInfo.h" (also validate checksum looks like a 64-char hex
string); skip malformed lines and raise a clear ValueError if no exact match is
found.



def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:
Expand Down
12 changes: 9 additions & 3 deletions flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from typing import List

from . import env as jit_env
from ..artifacts import ArtifactPath, MetaInfoHash
from ..artifacts import ArtifactPath, CheckSumHash
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations


Expand Down Expand Up @@ -180,10 +180,16 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_BMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_BMM,
meta_hash,
)
Comment on lines +183 to 193
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ› οΈ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Checksum → MetaInfo flow is correct; ensure name parity across modules.

Logic is good, but depends on get_meta_hash correctly matching MetaInfo and consistent header_name across files. After fixing get_meta_hash and aligning header_name in attention, this path should be stable.

To verify naming consistency across the repo:


🏁 Script executed:

#!/bin/bash
# Find all header_name assignments and their values
rg -n -C1 -P 'header_name\s*=\s*["\']flash(?:infer|Infer)MetaInfo["\']'

Length of output: 158


🏁 Script executed:

#!/bin/bash
# Search for header_name assignments across Python files
rg -n -C2 -e 'header_name\s*=\s*' --type py

# Search for MetaInfo occurrences across Python files
rg -n 'MetaInfo' --type py

Length of output: 4624


Unify header_name casing across modules

  • In flashinfer/jit/attention/modules.py (e.g. lines 1574, 1588), change "flashInferMetaInfo" to "flashinferMetaInfo" to match fused_moe.py and gemm/core.py. Update any artifacts.py entries and tests to use the same casing.

# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down
20 changes: 16 additions & 4 deletions flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jinja2
import torch

from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath, CheckSumHash
from .. import env as jit_env
from ..core import (
JitSpec,
Expand All @@ -30,7 +30,7 @@
sm100f_nvcc_flags,
current_compilation_context,
)
from ..cubin_loader import get_cubin
from ..cubin_loader import get_cubin, get_meta_hash
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different


Expand Down Expand Up @@ -361,10 +361,16 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_GEMM,
meta_hash,
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down Expand Up @@ -505,10 +511,16 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_GEMM,
meta_hash,
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down