Skip to content

Commit b73f04c

Browse files
authored
misc: Update artifacts docstring and MetaInfoHash (#1967)
<!-- .github/pull_request_template.md --> ## 📌 Description Amendment to [PR 1761](#1761): adds docstrings to the two artifactory classes (`ArtifactPath` and `CheckSumHash`) and removes the need to update `MetaInfoHash` manually, by reading the metainfo header hash directly from the `checksums.txt` file. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added runtime integrity checks for compiled artifacts that verify and use checksum data during loading to prevent missing or mismatched artifact headers. * **Refactor** * Switched artifact hash resolution to compute hashes dynamically from provided checksums, improving validation, reliability, and resilience when loading precompiled components. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0260ab3 commit b73f04c

File tree

5 files changed

+56
-21
lines changed

5 files changed

+56
-21
lines changed

flashinfer/artifacts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@ def get_available_cubin_files(
7979
return tuple()
8080

8181

82+
@dataclass(frozen=True)
8283
class ArtifactPath:
84+
"""
85+
This class is used to store the paths of the cubin files in artifactory.
86+
The paths are generated in cubin publishing script logs (accessible by codeowners).
87+
When compiling new cubins for backend directories, update the corresponding path.
88+
"""
89+
8390
TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
8491
TRTLLM_GEN_BMM: str = (
8592
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
@@ -106,6 +113,12 @@ class MetaInfoHash:
106113

107114

108115
class CheckSumHash:
116+
"""
117+
This class is used to store the checksums of the cubin files in artifactory.
118+
The sha256 hashes are generated in cubin publishing script logs (accessible by codeowners).
119+
When updating the ArtifactPath for backend directories, update the corresponding hash.
120+
"""
121+
109122
TRTLLM_GEN_FMHA: str = (
110123
"639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
111124
)

flashinfer/jit/attention/modules.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
sm90a_nvcc_flags,
2929
current_compilation_context,
3030
)
31-
from ...jit.cubin_loader import get_cubin
31+
from ...jit.cubin_loader import get_cubin, get_meta_hash
3232
from ..utils import (
3333
dtype_map,
3434
filename_safe_dtype_map,
@@ -1568,15 +1568,21 @@ def gen_fmha_cutlass_sm100a_module(
15681568

15691569

15701570
def gen_trtllm_gen_fmha_module():
1571-
from ...artifacts import ArtifactPath, MetaInfoHash
1571+
from ...artifacts import ArtifactPath, CheckSumHash
15721572

15731573
include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include"
15741574
header_name = "flashInferMetaInfo"
15751575

1576+
# Check if checksums.txt exists in the cubin directory
1577+
checksum_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/checksums.txt"
1578+
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_FMHA)
1579+
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
1580+
1581+
meta_hash = get_meta_hash(checksum)
15761582
# use `get_cubin` to get "flashinferMetaInfo.h"
15771583
metainfo = get_cubin(
15781584
f"{include_path}/{header_name}.h",
1579-
MetaInfoHash.TRTLLM_GEN_FMHA,
1585+
meta_hash,
15801586
)
15811587

15821588
# make sure "flashinferMetaInfo.h" is downloaded or cached
@@ -1592,7 +1598,7 @@ def gen_trtllm_gen_fmha_module():
15921598
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
15931599
extra_cuda_cflags=[
15941600
f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"',
1595-
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{MetaInfoHash.TRTLLM_GEN_FMHA}\\"',
1601+
f'-DTLLM_GEN_FMHA_METAINFO_HASH=\\"{meta_hash}\\"',
15961602
],
15971603
)
15981604

flashinfer/jit/cubin_loader.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,16 @@ def download_file(
136136
return False
137137

138138

139-
def get_meta_hash(checksum_path: str) -> str:
139+
def get_meta_hash(checksums_bytes: bytes) -> str:
140140
"""
141-
Load the file from local cache (checksums.txt)
142-
and get the hash of corresponding flashinferMetaInfo.h file
141+
Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file
143142
"""
144-
local_path = FLASHINFER_CUBIN_DIR / safe_urljoin(checksum_path, "checksums.txt")
145-
with open(local_path, "r") as f:
146-
for line in f:
147-
sha256, filename = line.strip().split()
148-
if ".h" in filename:
149-
return sha256
150-
raise ValueError(f"Invalid path: checksums.txt not found in {checksum_path}")
143+
checksums_lines = checksums_bytes.decode("utf-8").splitlines()
144+
for line in checksums_lines:
145+
sha256, filename = line.strip().split()
146+
if ".h" in filename:
147+
return sha256
148+
raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
151149

152150

153151
def verify_cubin(cubin_path: str, expected_sha256: str) -> bool:

flashinfer/jit/fused_moe.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from typing import List
1818

1919
from . import env as jit_env
20-
from ..artifacts import ArtifactPath, MetaInfoHash
20+
from ..artifacts import ArtifactPath, CheckSumHash
2121
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
2222
from .cpp_ext import is_cuda_version_at_least
23-
from .cubin_loader import get_cubin
23+
from .cubin_loader import get_cubin, get_meta_hash
2424
from .gemm.cutlass.generate_kernels import generate_gemm_operations
2525

2626

@@ -180,10 +180,16 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
180180
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
181181
header_name = "flashinferMetaInfo"
182182

183+
# Check if checksums.txt exists in the cubin directory
184+
checksum_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt"
185+
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_BMM)
186+
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
187+
meta_hash = get_meta_hash(checksum)
188+
183189
# use `get_cubin` to get "flashinferMetaInfo.h"
184190
metainfo = get_cubin(
185191
f"{include_path}/{header_name}.h",
186-
MetaInfoHash.TRTLLM_GEN_BMM,
192+
meta_hash,
187193
)
188194
# make sure "flashinferMetaInfo.h" is downloaded or cached
189195
assert metainfo, f"{header_name}.h not found"

flashinfer/jit/gemm/core.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jinja2
2121
import torch
2222

23-
from ...artifacts import ArtifactPath, MetaInfoHash
23+
from ...artifacts import ArtifactPath, CheckSumHash
2424
from .. import env as jit_env
2525
from ..core import (
2626
JitSpec,
@@ -30,7 +30,7 @@
3030
sm100f_nvcc_flags,
3131
current_compilation_context,
3232
)
33-
from ..cubin_loader import get_cubin
33+
from ..cubin_loader import get_cubin, get_meta_hash
3434
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different
3535

3636

@@ -361,10 +361,16 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
361361
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
362362
header_name = "flashinferMetaInfo"
363363

364+
# Check if checksums.txt exists in the cubin directory
365+
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
366+
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
367+
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
368+
meta_hash = get_meta_hash(checksum)
369+
364370
# use `get_cubin` to get "flashinferMetaInfo.h"
365371
metainfo = get_cubin(
366372
f"{include_path}/{header_name}.h",
367-
MetaInfoHash.TRTLLM_GEN_GEMM,
373+
meta_hash,
368374
)
369375
# make sure "flashinferMetaInfo.h" is downloaded or cached
370376
assert metainfo, f"{header_name}.h not found"
@@ -505,10 +511,16 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
505511
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include"
506512
header_name = "flashinferMetaInfo"
507513

514+
# Check if checksums.txt exists in the cubin directory
515+
checksum_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/checksums.txt"
516+
checksum = get_cubin(checksum_path, CheckSumHash.TRTLLM_GEN_GEMM)
517+
assert checksum, f"Failed to get checksums.txt from {checksum_path}"
518+
meta_hash = get_meta_hash(checksum)
519+
508520
# use `get_cubin` to get "flashinferMetaInfo.h"
509521
metainfo = get_cubin(
510522
f"{include_path}/{header_name}.h",
511-
MetaInfoHash.TRTLLM_GEN_GEMM,
523+
meta_hash,
512524
)
513525
# make sure "flashinferMetaInfo.h" is downloaded or cached
514526
assert metainfo, f"{header_name}.h not found"

0 commit comments

Comments
 (0)