Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ def get_available_cubin_files(
return tuple()


@dataclass(frozen=True)
class ArtifactPath:
"""
This class is used to store the paths of the cubin files in artifactory.
The paths are generated in cubin publishing script logs (accessible by codeowners).
When compiling new cubins for backend directories, update the corresponding path.
"""
TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
Expand All @@ -106,6 +112,11 @@ class MetaInfoHash:


class CheckSumHash:
"""
This class is used to store the checksums of the cubin files in artifactory.
The sha256 hashes are generated in cubin publishing script logs (accessible by codeowners).
When updating the ArtifactPath for backend directories, update the corresponding hash.
"""
TRTLLM_GEN_FMHA: str = (
"639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
)
Expand Down
14 changes: 10 additions & 4 deletions flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sm90a_nvcc_flags,
current_compilation_context,
)
from ...jit.cubin_loader import get_cubin
from ...jit.cubin_loader import get_cubin, get_meta_hash,
from ..utils import (
dtype_map,
filename_safe_dtype_map,
Expand Down Expand Up @@ -1568,15 +1568,21 @@ def gen_fmha_cutlass_sm100a_module(


def gen_trtllm_gen_fmha_module():
from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath, CheckSumHash

include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include"
header_name = "flashInferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_FMHA)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"

meta_hash = get_meta_hash(checksum)
# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_FMHA,
meta_hash,
)

# make sure "flashinferMetaInfo.h" is downloaded or cached
Expand All @@ -1592,7 +1598,7 @@ def gen_trtllm_gen_fmha_module():
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
extra_cuda_cflags=[
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{meta_hash}\\"',
],
)

Expand Down
17 changes: 7 additions & 10 deletions flashinfer/jit/cubin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,15 @@ def download_file(
return False


def get_meta_hash(checksum_path: str) -> str:
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Load the file from local cache (checksums.txt)
and get the hash of corresponding flashinferMetaInfo.h file
Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
"""
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
with open(local_path, "r") as f:
for line in f:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
for line in checksums_bytes.splitlines():
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid checksums.txt, no flashinferMetaInfo.h found")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

Bytes/str bug in get_meta_hash; also lint F541.

Currently iterates bytes, compares str ".h" to bytes, returns bytes; will TypeError/mismatch and fail hash compare. Also f-string has no placeholders.

Apply:

-def get_meta_hash(checksums_bytes: bytes) -> str:
-    """
-    Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
-    """
-    for line in checksums_bytes.splitlines():
-        sha256, filename = line.strip().split()
-        if ".h" in filename:
-            return sha256
-    raise ValueError(f"Invalid checksums.txt, no flashinferMetaInfo.h found")
+def get_meta_hash(checksums_bytes: bytes) -> str:
+    """
+    Parse checksums.txt content (bytes) and return the sha256 for the MetaInfo header.
+    """
+    text = checksums_bytes.decode("utf-8", "replace")
+    for line in text.splitlines():
+        parts = line.strip().split()
+        if len(parts) != 2:
+            continue
+        sha256, filename = parts  # both str
+        name = filename.lower()
+        if name.endswith("metainfo.h"):
+            return sha256
+    raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Load the file from local cache (checksums.txt)
and get the hash of corresponding flashinferMetaInfo.h file
Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
"""
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
with open(local_path, "r") as f:
for line in f:
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
for line in checksums_bytes.splitlines():
sha256, filename = line.strip().split()
if ".h" in filename:
return sha256
raise ValueError(f"Invalid checksums.txt, no flashinferMetaInfo.h found")
def get_meta_hash(checksums_bytes: bytes) -> str:
"""
Parse checksums.txt content (bytes) and return the sha256 for the MetaInfo header.
"""
text = checksums_bytes.decode("utf-8", "replace")
for line in text.splitlines():
parts = line.strip().split()
if len(parts) != 2:
continue
sha256, filename = parts # both str
name = filename.lower()
if name.endswith("metainfo.h"):
return sha256
raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
🧰 Tools
πŸͺ› GitHub Actions: pre-commit

[error] 147-147: F541: f-string without any placeholders. Remove extraneous 'f' prefix.

πŸͺ› Ruff (0.14.1)

147-147: Avoid specifying long messages outside the exception class

(TRY003)


147-147: f-string without any placeholders

Remove extraneous f prefix

(F541)

πŸ€– Prompt for AI Agents
In flashinfer/jit/cubin_loader.py around lines 139 to 148, fix the bytes/str
mismatch and the f-string lint: ensure checksums_bytes is decoded to text (e.g.,
decode to UTF-8 or accept str input), iterate over text lines, split each line
into sha256 and filename (use split(maxsplit=1)), check filename.endswith('.h')
(string comparison), return the sha256 as a str, and replace the f-string in the
raise with a normal string literal to avoid F541.


def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:
Expand Down
12 changes: 9 additions & 3 deletions flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from typing import List

from . import env as jit_env
from ..artifacts import ArtifactPath, MetaInfoHash
from ..artifacts import ArtifactPath, CheckSumHash
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations


Expand Down Expand Up @@ -180,10 +180,16 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_BMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_BMM,
meta_hash,
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down
20 changes: 16 additions & 4 deletions flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jinja2
import torch

from ...artifacts import ArtifactPath, MetaInfoHash
from ...artifacts import ArtifactPath, CheckSumHash
from .. import env as jit_env
from ..core import (
JitSpec,
Expand All @@ -30,7 +30,7 @@
sm100f_nvcc_flags,
current_compilation_context,
)
from ..cubin_loader import get_cubin
from ..cubin_loader import get_cubin, get_meta_hash
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different


Expand Down Expand Up @@ -361,10 +361,16 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_GEMM,
meta_hash,
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down Expand Up @@ -505,10 +511,16 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
header_name = "flashinferMetaInfo"

# Check if checksums.txt exists in the cubin directory
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
meta_hash = get_meta_hash(checksum)

# use `get_cubin` to get "flashinferMetaInfo.h"
metainfo = get_cubin(
f"{include_path}/{header_name}.h",
MetaInfoHash.TRTLLM_GEN_GEMM,
meta_hash,
)
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"
Expand Down
Loading