refactor: pull trtllm-gen batch-gemm/gemm headers from artifactory; update tma descriptor shape init#2235
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughPropagates mValidM/mValidN/mValidK alongside computed mM/mN/mK in multiple GEMM runner code paths, adds JIT header discovery/download/symlink utilities and integrates header fetching into JIT module assembly, and removes a large set of exported batched_gemm/gemm public headers (KernelParams, KernelTraits, Options, Interfaces, enums, and related declarations). Changes
Sequence Diagram(s)sequenceDiagram
participant JIT as JIT Generator
participant CL as CubinLoader (get_file / get_meta_hash / make_symlink)
participant FS as Filesystem/Cache (FLASHINFER_CUBIN_DIR)
participant COMP as Compiler
participant R as Runtime (C++ GEMM runners)
JIT->>CL: request header list & meta (get_meta_hash / get_available_header_files)
CL->>FS: check cache & validate sha256 (get_file)
alt header missing or checksum mismatch
CL->>FS: download header and write to cache
end
CL->>FS: create symlink to export/include dir (make_symlink)
JIT->>COMP: compile with extra_include_paths (FLASHINFER_CUBIN_DIR, ...)
COMP->>FS: read headers via symlink
COMP->>R: produce cubin/module
R->>R: use ProblemDimensions (mM/mN/mK and mValidM/mValidN/mValidK) at runtime
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip Try Coding Plans. Let us write the prompt for your AI agent so you can ship faster (with fewer bugs). Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello @jimmyzho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request streamlines the management of TRTLLM-Gen headers by transitioning from bundled files to fetching them dynamically from Artifactory. It also includes crucial updates to how TMA problem dimensions are initialized in the CUDA GEMM runners, ensuring compatibility and correctness with the updated header fetching mechanism. The changes enhance the flexibility and maintainability of the artifact management system. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request refactors the build process to download trt-llm headers from an artifactory during JIT compilation, which is a commendable improvement for dependency management. The changes also update TMA descriptor initializations, likely to align with a newer API.
My main feedback revolves around improving maintainability. There's significant code duplication for downloading headers and creating symlinks across flashinfer/jit/fused_moe.py and flashinfer/jit/gemm/core.py, which should be refactored into a shared helper function. Additionally, the lists of header files are hardcoded, creating a maintenance burden; a more dynamic approach would be better. There are also opportunities to reduce duplication between get_available_header_files and get_available_cubin_files and to clean up commented-out code.
| def get_available_header_files( | ||
| source: str, retries: int = 3, delay: int = 5, timeout: int = 10 | ||
| ) -> tuple[str, ...]: | ||
| """ | ||
| Recursively navigates through child directories (e.g., include/) and finds | ||
| all *.h header files, returning them as a tuple of relative paths. | ||
| """ | ||
| result: list[str] = [] | ||
|
|
||
| def fetch_directory(url: str, prefix: str = "") -> None: | ||
| for attempt in range(1, retries + 1): | ||
| try: | ||
| response = requests.get(url, timeout=timeout) | ||
| response.raise_for_status() | ||
|
|
||
| # Find all .h header files in this directory | ||
| header_hrefs = re.findall(r'<a href="([^"]+\.h)">', response.text) | ||
| for h in header_hrefs: | ||
| result.append(prefix + h if prefix else h) | ||
|
|
||
| # Find all subdirectories (links ending with /) | ||
| dir_hrefs = re.findall(r'<a href="([^"]+/)">', response.text) | ||
| for d in dir_hrefs: | ||
| # Skip parent directory links | ||
| if d == "../" or d.startswith(".."): | ||
| continue | ||
| subdir_url = safe_urljoin(url, d) | ||
| subdir_prefix = prefix + d if prefix else d | ||
| fetch_directory(subdir_url, subdir_prefix) | ||
|
|
||
| return # Success, exit retry loop | ||
|
|
||
| except requests.exceptions.RequestException as e: | ||
| logger.warning( | ||
| f"Fetching available header files {url}: attempt {attempt} failed: {e}" | ||
| ) | ||
|
|
||
| if attempt < retries: | ||
| logger.info(f"Retrying in {delay} seconds...") | ||
| time.sleep(delay) | ||
|
|
||
| logger.error(f"Max retries reached for {url}. Fetch failed.") | ||
|
|
||
| fetch_directory(source) | ||
| logger.info(f"result: {result}") | ||
| return tuple(result) |
There was a problem hiding this comment.
This new function get_available_header_files has very similar logic to the existing get_available_cubin_files function. To improve maintainability and reduce code duplication, consider refactoring them into a single, more generic function. This new function could accept the file extension (e.g., .h or .cubin) as a parameter.
flashinfer/artifacts.py
Outdated
| TRTLLM_GEN_BMM: str = ( | ||
| "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" | ||
| "02546c924085adc5df7dc0a211cacc7ec3d3e01c/batched_gemm-0d275a2-9936841" | ||
| ) | ||
| # TRTLLM_GEN_BMM: str = ( | ||
| # "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" | ||
| # ) | ||
| TRTLLM_GEN_GEMM: str = ( | ||
| "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" | ||
| "02546c924085adc5df7dc0a211cacc7ec3d3e01c/gemm-0d275a2-30f1102" | ||
| ) | ||
| # TRTLLM_GEN_GEMM: str = ( | ||
| # "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" | ||
| # ) |
There was a problem hiding this comment.
flashinfer/artifacts.py
Outdated
| TRTLLM_GEN_BMM: str = ( | ||
| "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" | ||
| "680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90" | ||
| ) | ||
| # TRTLLM_GEN_BMM: str = ( | ||
| # "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" | ||
| # ) | ||
| DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" | ||
| TRTLLM_GEN_GEMM: str = ( | ||
| "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9" | ||
| "014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60" | ||
| ) | ||
| # TRTLLM_GEN_GEMM: str = ( | ||
| # "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9" | ||
| # ) |
There was a problem hiding this comment.
Please remove the commented-out lines containing old hashes to keep the code clean.
TRTLLM_GEN_BMM: str = (
"680167f34b532d493d3ed71da0a1640054cf1cb0a80cfca20e7d797dbd093a90"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
"014473f273a4dd248b5608e813f0fe468f05c686093577abd23f7a64afd77a60"
)| Otherwise, download the file from {uri_path} and write to {file_path}. | ||
| """ | ||
|
|
||
| file = load_cubin(file_path, sha256) |
There was a problem hiding this comment.
flashinfer/jit/fused_moe.py
Outdated
| header_files = [ | ||
| "BatchedGemmEnums.h", | ||
| "BatchedGemmInterface.h", | ||
| "BatchedGemmOptions.h", | ||
| "Enums.h", | ||
| "GemmGatedActOptions.h", | ||
| "GemmOptions.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "KernelTraits.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| ] |
There was a problem hiding this comment.
The list of header_files is hardcoded here and in other similar functions. This makes it difficult to maintain when the upstream trt-llm dependency adds or removes headers.
Consider making this more dynamic. You could, for example, use the new get_available_header_files function to fetch the list of headers from the artifactory directly, rather than hardcoding them.
There was a problem hiding this comment.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/jit/gemm/core.py (1)
565-607: Consider extracting shared header download logic and verify include path consistency.The header file download logic (lines 566-592) is duplicated between
gen_trtllm_gen_gemm_moduleandgen_trtllm_low_latency_gemm_module. Additionally, there's an inconsistency inextra_include_paths:gen_trtllm_gen_gemm_moduleincludes bothFLASHINFER_CUBIN_DIRandFLASHINFER_CUBIN_DIR / include_path(lines 424-427), whilegen_trtllm_low_latency_gemm_moduleonly includesFLASHINFER_CUBIN_DIR / include_path(line 607).
- Extract the header download logic into a shared helper function
- Verify whether the include path difference is intentional or an oversight - if the symlink at
flashinfer/trtllm/gemm/trtllmGen_gemm_exportis needed, thenFLASHINFER_CUBIN_DIRshould be included here as wellExample refactor:
def download_trtllm_gemm_headers(artifact_path: str, checksum_hash: str): """Download and cache TRTLLM GEMM export headers.""" include_path = f"{artifact_path}/include" checksum_path = f"{artifact_path}/checksums.txt" checksum = get_cubin(checksum_path, checksum_hash) assert checksum, f"Failed to get checksums.txt from {checksum_path}" meta_hash = get_meta_hash(checksum) header_name = "flashinferMetaInfo" metainfo = get_cubin(f"{include_path}/{header_name}.h", meta_hash) assert metainfo, f"{header_name}.h not found" header_files = [ "GemmInterface.h", "GemmOptions.h", # ... rest of the list ] header_path = f"{include_path}/trtllmGen_gemm_export" for file in header_files: uri_path = f"{header_path}/{file}" file_hash = get_meta_hash(checksum, file) file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file get_file(uri_path, file_hash, str(file_path)) symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm") make_symlink("../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export") return include_path
🧹 Nitpick comments (2)
flashinfer/jit/cubin_loader.py (1)
139-151: Clarify the inline comment.The comment mentions "case-insensitive for the 'I' in Infer" but the implementation performs case-insensitive matching on the entire filename using
.lower().endswith(target_file.lower()). This may confuse future maintainers.Consider updating the comment to:
- # Match specifically flashinferMetaInfo.h (case-insensitive for the 'I' in Infer) + # Match filename ending with target_file (case-insensitive)flashinfer/artifacts.py (1)
82-127: Consider reducing log verbosity.Line 126 logs the entire list of header files at INFO level, which could be verbose for directories with many headers. Consider logging the count instead or moving this to DEBUG level.
- logger.info(f"result: {result}") + logger.info(f"Found {len(result)} header files") + logger.debug(f"Header files: {result}")
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (11)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.his excluded by!**/gen/**include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.his excluded by!**/gen/**include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.his excluded by!**/gen/**
📒 Files selected for processing (24)
csrc/trtllm_batched_gemm_runner.cu(2 hunks)csrc/trtllm_gemm_runner.cu(3 hunks)csrc/trtllm_low_latency_gemm_runner.cu(1 hunks)flashinfer/artifacts.py(5 hunks)flashinfer/jit/cubin_loader.py(2 hunks)flashinfer/jit/fused_moe.py(3 hunks)flashinfer/jit/gemm/core.py(4 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h(0 hunks)include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h(0 hunks)include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h(0 hunks)
💤 Files with no reviewable changes (17)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/trtllm_gemm_runner.cucsrc/trtllm_batched_gemm_runner.cu
🧬 Code graph analysis (3)
flashinfer/jit/fused_moe.py (1)
flashinfer/jit/cubin_loader.py (4)
get_cubin(227-246)get_meta_hash(139-151)make_symlink(195-203)get_file(206-224)
flashinfer/jit/gemm/core.py (1)
flashinfer/jit/cubin_loader.py (3)
get_meta_hash(139-151)make_symlink(195-203)get_file(206-224)
flashinfer/artifacts.py (1)
flashinfer/jit/cubin_loader.py (1)
safe_urljoin(38-42)
🪛 Ruff (0.14.8)
flashinfer/jit/cubin_loader.py
209-209: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
flashinfer/artifacts.py
112-112: Consider moving this statement to an else block
(TRY300)
🔇 Additional comments (15)
csrc/trtllm_low_latency_gemm_runner.cu (1)
52-54: LGTM!The initialization of the "valid" dimension fields mirrors the primary dimension fields, which aligns with the PR's objective to update TMA descriptor shape initialization.
csrc/trtllm_gemm_runner.cu (3)
122-124: LGTM!The valid dimension fields are correctly initialized after the primary dimensions, maintaining consistency with the transpose logic already applied to mM and mN.
145-147: LGTM!The initialization pattern is consistent with the getWorkspaceSizeInBytes method.
196-198: LGTM!The initialization pattern is uniformly applied across all three methods (getWorkspaceSizeInBytes, run, and getValidTactics).
csrc/trtllm_batched_gemm_runner.cu (3)
149-151: LGTM!The valid dimension fields are correctly initialized in the workspace sizing method.
343-345: LGTM!The valid dimension fields are correctly initialized in the config validation method.
448-450: LGTM!The valid dimension fields are correctly initialized in the config validation method.
flashinfer/artifacts.py (4)
139-150: LGTM!The artifact paths have been updated with new hashes, and the old values are preserved as comments for reference or rollback. This aligns with the PR objective to pull trtllm-gen headers from Artifactory.
166-178: LGTM!The checksums have been updated to match the new artifact paths. Good practice to preserve the old values as comments.
245-247: LGTM!The header files are now included in the artifact listing, mirroring the pattern used for cubin files. This enables header files to be downloaded and verified alongside cubin binaries.
256-256: LGTM!Adding explicit type annotation improves code clarity.
flashinfer/jit/gemm/core.py (2)
33-38: LGTM!The new imports support the header file management functionality added in this PR.
424-427: LGTM!The include paths are extended to support both the root cubin directory (for the symlink at
flashinfer/trtllm/gemm/trtllmGen_gemm_export) and the artifact-specific include directory. This enables the compilation to locate the downloaded headers.flashinfer/jit/fused_moe.py (2)
29-29: LGTM!The new imports support the header file management functionality added in this PR.
298-303: LGTM!The include paths are extended to support both the root cubin directory and the artifact-specific include directory, consistent with the pattern in
gen_trtllm_gen_gemm_module.
flashinfer/jit/gemm/core.py
Outdated
|
|
||
| header_files = [ | ||
| "GemmInterface.h", | ||
| "GemmOptions.h", | ||
| "Enums.h", | ||
| "KernelTraits.h", | ||
| "KernelParams.h", | ||
| "KernelParamsDecl.h", | ||
| "TmaDescriptor.h", | ||
| "trtllm/gen/CommonUtils.h", | ||
| "trtllm/gen/CudaKernelLauncher.h", | ||
| "trtllm/gen/DtypeDecl.h", | ||
| "trtllm/gen/MmaDecl.h", | ||
| "trtllm/gen/SfLayoutDecl.h", | ||
| "trtllm/gen/CudaArchDecl.h", | ||
| ] | ||
|
|
||
| header_path = f"{include_path}/trtllmGen_gemm_export" | ||
| for file in header_files: | ||
| uri_path = f"{header_path}/{file}" | ||
| file_hash = get_meta_hash(checksum, file) | ||
| file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file | ||
| get_file(uri_path, file_hash, file_path) | ||
| # Create directory flashinfer/trtllm/gemm/trtllmGen_gemm_export pointing to trtllmGen_gemm_export | ||
| symlink_parent = str(jit_env.FLASHINFER_CUBIN_DIR / "flashinfer/trtllm/gemm") | ||
| make_symlink( | ||
| "../../../trtllmGen_gemm_export", symlink_parent, "trtllmGen_gemm_export" | ||
| ) |
There was a problem hiding this comment.
Add error handling for header file downloads.
The get_file function returns empty bytes on failure (from load_cubin at line 224 of cubin_loader.py), but there's no verification that the download succeeded. This could lead to compilation failures with cryptic error messages.
Consider adding validation after the download loop:
header_path = f"{include_path}/trtllmGen_gemm_export"
downloaded_files = []
for file in header_files:
uri_path = f"{header_path}/{file}"
file_hash = get_meta_hash(checksum, file)
file_path = jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export" / file
result = get_file(uri_path, file_hash, str(file_path))
if not result:
raise RuntimeError(f"Failed to download header file: {file}")
downloaded_files.append(file)🤖 Prompt for AI Agents
In flashinfer/jit/gemm/core.py around lines 382 to 409, the loop that calls
get_file for each header does not verify the download succeeded (get_file can
return empty bytes), so add immediate validation after each get_file call:
capture the return value, if it is falsy/empty raise a RuntimeError naming the
missing header (e.g. f"Failed to download header file: {file}"), and optionally
collect successful filenames into a list; ensure you pass the correct path/type
to get_file as required by its signature and fail fast to avoid later cryptic
compilation errors.
There was a problem hiding this comment.
@jimmyzho can you verify if this error check issue is valid? let's address error checks if so
There was a problem hiding this comment.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
|
@aleozlx can you help reviewing this PR? |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/jit/gemm/core.py (1)
652-654: Inconsistent include paths between GEMM module generators.
gen_trtllm_gen_gemm_module(lines 470-473) includes bothjit_env.FLASHINFER_CUBIN_DIRandjit_env.FLASHINFER_CUBIN_DIR / include_path, butgen_trtllm_low_latency_gemm_moduleonly includesjit_env.FLASHINFER_CUBIN_DIR / include_path. This inconsistency may cause header resolution failures since the symlink atflashinfer/trtllm/gemm/trtllmGen_gemm_exportrequires the base cubin directory in the include path.Proposed fix
# link "include" sub-directory in cache - extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], + extra_include_paths=[ + jit_env.FLASHINFER_CUBIN_DIR, + jit_env.FLASHINFER_CUBIN_DIR / include_path, + ], extra_ldflags=["-lcuda"],
🤖 Fix all issues with AI agents
In `@flashinfer/artifacts.py`:
- Around line 82-127: The recursive fetch_directory inside
get_available_header_files currently swallows failures after max retries,
causing silent incomplete results; modify fetch_directory to raise an exception
when max retries are reached (include the URL and the last caught exception)
instead of just logging and returning, and allow that exception to propagate out
of get_available_header_files so callers are aware of failures (capture the last
requests.exceptions.RequestException in the except block and raise a
RuntimeError or re-raise the original exception with contextual message).
In `@flashinfer/jit/gemm/core.py`:
- Around line 629-633: get_file is receiving a pathlib.Path for file_path which
causes a type mismatch; change the call in the loop to pass a string instead of
a Path. In the for loop that iterates header_files (using header_path, checksum,
get_meta_hash, and jit_env.FLASHINFER_CUBIN_DIR / "trtllmGen_gemm_export"),
convert file_path to str (e.g., str(file_path)) when calling get_file(uri_path,
file_hash, file_path) so the third argument matches get_file's expected str
type.
🧹 Nitpick comments (2)
flashinfer/jit/gemm/core.py (1)
429-455: Consider extracting duplicated header download logic.The header file download and symlink creation logic is nearly identical between
gen_trtllm_gen_gemm_module(lines 429-455) andgen_trtllm_low_latency_gemm_module(lines 612-638), and also similar togen_trtllm_gen_fused_moe_sm100_moduleinfused_moe.py. Consider extracting this into a shared helper function to reduce duplication and ensure consistent error handling.Also applies to: 612-638
flashinfer/artifacts.py (1)
244-244: Unusual type annotation syntax.The syntax
list[tuple[str, str]](...)works but is unconventional. The more common approach is justlist(...)with type hints elsewhere, or a simple cast if needed.Alternative syntax
- cubin_files = list[tuple[str, str]](get_subdir_file_list()) + cubin_files: list[tuple[str, str]] = list(get_subdir_file_list())
| def get_available_header_files( | ||
| source: str, retries: int = 3, delay: int = 5, timeout: int = 10 | ||
| ) -> tuple[str, ...]: | ||
| """ | ||
| Recursively navigates through child directories (e.g., include/) and finds | ||
| all *.h header files, returning them as a tuple of relative paths. | ||
| """ | ||
| result: list[str] = [] | ||
|
|
||
| def fetch_directory(url: str, prefix: str = "") -> None: | ||
| for attempt in range(1, retries + 1): | ||
| try: | ||
| response = requests.get(url, timeout=timeout) | ||
| response.raise_for_status() | ||
|
|
||
| # Find all .h header files in this directory | ||
| header_hrefs = re.findall(r'<a href="([^"]+\.h)">', response.text) | ||
| for h in header_hrefs: | ||
| result.append(prefix + h if prefix else h) | ||
|
|
||
| # Find all subdirectories (links ending with /) | ||
| dir_hrefs = re.findall(r'<a href="([^"]+/)">', response.text) | ||
| for d in dir_hrefs: | ||
| # Skip parent directory links | ||
| if d == "../" or d.startswith(".."): | ||
| continue | ||
| subdir_url = safe_urljoin(url, d) | ||
| subdir_prefix = prefix + d if prefix else d | ||
| fetch_directory(subdir_url, subdir_prefix) | ||
|
|
||
| return # Success, exit retry loop | ||
|
|
||
| except requests.exceptions.RequestException as e: | ||
| logger.warning( | ||
| f"Fetching available header files {url}: attempt {attempt} failed: {e}" | ||
| ) | ||
|
|
||
| if attempt < retries: | ||
| logger.info(f"Retrying in {delay} seconds...") | ||
| time.sleep(delay) | ||
|
|
||
| logger.error(f"Max retries reached for {url}. Fetch failed.") | ||
|
|
||
| fetch_directory(source) | ||
| logger.info(f"result: {result}") | ||
| return tuple(result) |
There was a problem hiding this comment.
Silent failures in recursive directory traversal may cause incomplete results.
The nested fetch_directory function logs errors but doesn't propagate failures - it silently continues, potentially returning incomplete results. If a subdirectory fails to fetch after retries, callers won't know some headers are missing. Consider either:
- Raising an exception after max retries
- Returning a success indicator along with results
Proposed fix to raise on failure
if attempt < retries:
logger.info(f"Retrying in {delay} seconds...")
time.sleep(delay)
- logger.error(f"Max retries reached for {url}. Fetch failed.")
+ logger.error(f"Max retries reached for {url}. Fetch failed.")
+ raise RuntimeError(f"Failed to fetch header files from {url}")
fetch_directory(source)🧰 Tools
🪛 Ruff (0.14.14)
112-112: Consider moving this statement to an else block
(TRY300)
🤖 Prompt for AI Agents
In `@flashinfer/artifacts.py` around lines 82 - 127, The recursive fetch_directory
inside get_available_header_files currently swallows failures after max retries,
causing silent incomplete results; modify fetch_directory to raise an exception
when max retries are reached (include the URL and the last caught exception)
instead of just logging and returning, and allow that exception to propagate out
of get_available_header_files so callers are aware of failures (capture the last
requests.exceptions.RequestException in the except block and raise a
RuntimeError or re-raise the original exception with contextual message).
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/jit/cubin_loader.py`:
- Around line 139-151: The ValueError in get_meta_hash incorrectly hardcodes
"flashinferMetaInfo.h" in its message even though the function accepts a
target_file parameter; update the exception raised in get_meta_hash to include
the actual target_file (e.g., use the target_file variable in the message) so
the error reflects which filename was not found and aids debugging.
| def get_meta_hash( | ||
| checksums_bytes: bytes, target_file: str = "flashinferMetaInfo.h" | ||
| ) -> str: | ||
| """ | ||
| Parse the checksums.txt file and get the hash of corresponding flashinferMetaInfo.h file | ||
| """ | ||
| checksums_lines = checksums_bytes.decode("utf-8").splitlines() | ||
| for line in checksums_lines: | ||
| sha256, filename = line.strip().split() | ||
| if ".h" in filename: | ||
| # Match specifically flashinferMetaInfo.h (case-insensitive for the 'I' in Infer) | ||
| if filename.lower().endswith(target_file.lower()): | ||
| return sha256 | ||
| raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found") |
There was a problem hiding this comment.
Update error message to include the actual target file.
The error message on line 151 is hardcoded to mention flashinferMetaInfo.h, but the function now accepts a configurable target_file parameter. When searching for a different file (e.g., "GemmInterface.h"), the error message will be misleading.
Proposed fix
- raise ValueError("Invalid checksums.txt, no flashinferMetaInfo.h found")
+ raise ValueError(f"Invalid checksums.txt, no {target_file} found")🧰 Tools
🪛 Ruff (0.14.14)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@flashinfer/jit/cubin_loader.py` around lines 139 - 151, The ValueError in
get_meta_hash incorrectly hardcodes "flashinferMetaInfo.h" in its message even
though the function accepts a target_file parameter; update the exception raised
in get_meta_hash to include the actual target_file (e.g., use the target_file
variable in the message) so the error reflects which filename was not found and
aids debugging.
|
/bot run |
|
/bot run |
|
[FAILED] Pipeline #44496192: 13/20 passed |
|
there seems to be some failures https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/271105330 |
|
/bot run |
|
[FAILED] Pipeline #44675585: 14/20 passed |
…le to break circular import artifacts.py imports from jit.cubin_loader at module level, which triggers jit/__init__.py to load, which imports gen_moe_utils_module from this file, which then tries to import ArtifactPath/CheckSumHash from artifacts — but artifacts.py hasn't finished initializing yet, so those names don't exist. Fixes ImportError introduced in flashinfer-ai#2235. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
#2681) …le to break circular import artifacts.py imports from jit.cubin_loader at module level, which triggers jit/__init__.py to load, which imports gen_moe_utils_module from this file, which then tries to import ArtifactPath/CheckSumHash from artifacts — but artifacts.py hasn't finished initializing yet, so those names don't exist. Fixes ImportError introduced in #2235. See [Internal Pipeline](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/pipelines/45223374) for failure example. <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved internal code organization to enhance maintainability and prevent potential dependency issues. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…pdate tma descriptor shape init (flashinfer-ai#2235) <!-- .github/pull_request_template.md --> ## 📌 Description ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Ensured consistent validation/propagation of GEMM/batched-GEMM dimensions (M/N/K) across code paths. * **New Features** * JIT tooling now discovers, downloads, verifies and caches required header dependencies and creates symlinks to ensure reliable module compilation. * Added file-fetching and symlink helpers used by the JIT pipeline. * **Refactor** * Removed legacy GEMM/batched-GEMM public headers and interfaces to streamline the public API surface. * **Chores** * Updated prebuilt GEMM artifact paths and checksums. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
flashinfer-ai#2681) …le to break circular import artifacts.py imports from jit.cubin_loader at module level, which triggers jit/__init__.py to load, which imports gen_moe_utils_module from this file, which then tries to import ArtifactPath/CheckSumHash from artifacts — but artifacts.py hasn't finished initializing yet, so those names don't exist. Fixes ImportError introduced in flashinfer-ai#2235. See [Internal Pipeline](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/pipelines/45223374) for failure example. <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved internal code organization to enhance maintainability and prevent potential dependency issues. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
Bug Fixes
New Features
Refactor
Chores