feat: enable deepgemm jit for fp8 block-scale on SM90 (#1969)

yzh119 merged 3 commits into flashinfer-ai:main
Conversation
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Walkthrough

The pull request modifies the JIT compilation configuration for the TensorRT LLM DeepGEMM module. The include-directory discovery mechanism now uses the flashinfer-python package instead of tensorrt_llm, with updated path resolution. Additionally, the NVCC usage logic is simplified to always use NVCC instead of reading an environment variable.
Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates JIT compilation capabilities for FP8 DeepGEMM kernels into the FlashInfer project. The primary goal is to enhance performance by allowing dynamic compilation of these kernels. This is achieved by ensuring that NVCC is always used for JIT compilation and by correctly configuring the include paths to resolve dependencies within the flashinfer-python package.
Code Review
This pull request enables JIT compilation for FP8 DeepGEMM kernels, defaulting to NVCC. The changes correctly update package paths from tensorrt_llm to flashinfer-python to locate necessary headers. My review focuses on improving code clarity and maintainability. I've suggested updating a log message to reflect the package change and cleaning up commented-out code while preserving important context in a TODO comment.
```diff
 static bool kJitUseNvcc = []() {
-  char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
-  return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
+  // return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // always use nvcc
+  // TODO: Enable nvrtc -- need these headers:
+  // [TensorRT-LLM][INFO] Compilation log:
+  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
+  return true;
 }();
```
The kJitUseNvcc variable is now hardcoded to true, and the previous implementation that used an environment variable is commented out. While the TODO comment explains why NVRTC is currently disabled, leaving commented-out code can reduce readability and maintainability. It would be cleaner to remove the commented-out lines and refine the comments to keep only the essential context.
```cpp
static bool kJitUseNvcc = []() {
  // NVRTC is currently disabled due to include issues. Using NVCC by default.
  // TODO: Enable NVRTC. It fails with errors like:
  // [TensorRT-LLM][INFO] Compilation log:
  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
  return true;
}();
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh (2)
154-179: Update the outdated comment to reflect the new package name.

The comment at line 155 still references `tensorrt_llm`, but the code now uses `flashinfer-python`. Update the comment for consistency. Apply this diff to update the comment:

```diff
 // Parse the location using regex
-// `pip show tensorrt_llm` will output something like:
+// `pip show flashinfer-python` will output something like:
 // Location: /usr/local/lib/python3.12/dist-packages
 // Editable project location: /code
```
180-182: Update the error message to reflect the new package dependency.

The error message references "TensorRT LLM installation" but the code now searches for `flashinfer-python`. Update the message for consistency. Apply this diff to update the error message:

```diff
 } else {
-  TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+  TLLM_LOG_WARNING("Failed to find flashinfer-python installation, DeepGEMM will be disabled.");
 }
```
📒 Files selected for processing (2)
- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh (2 hunks)
- csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh (1 hunks)
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh (1)
38-46: LGTM! Appropriate simplification to force NVCC usage.

The hardcoded `return true` with commented-out environment-variable logic is appropriate given the TODO note about missing headers for NVRTC. This ensures consistent behavior until NVRTC support is fully enabled.
```diff
 if (includeDirs.empty()) {
   // Command to execute
-  char const* cmd = "pip show tensorrt_llm 2>/dev/null";
+  char const* cmd = "pip show flashinfer-python 2>/dev/null";
```
What's the purpose of this command?
For the DeepGEMM JIT, the compiler needs the header files in `deep_gemm/`; this command finds the installation path, which is then used further down to add `deep_gemm/` to the `-I` include flags.
I'd tend to move the logic to Python; `pip show flashinfer-python` doesn't necessarily show the correct package information (e.g. at AOT time, when the package is not installed yet).
Or we can obtain the include path from python and pass the value to C++.
I think this is where a refactor might be necessary, unfortunately these deep_gemm kernels aren't captured as part of AOT.
```diff
   }
 } else {
-  TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+  TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
```
I guess we can safely assume flashinfer is installed if this function is called?
Follow-up PR (#2090): pass the DeepGEMM JIT include directory through Python APIs

## 📌 Description

This PR implements the refactor mentioned in https://github.com/flashinfer-ai/flashinfer/pull/1969/files#r2461856020

In our current design we rely on calling `pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null` to obtain the DeepGEMM JIT include directory, which is error-prone (e.g. it will fail if the user does not have `pip` available in the environment). In this PR we pass the DeepGEMM JIT include directory through Python APIs.

## 🔍 Related Issues

#1969

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @djmmoss @jiahanc @nvmbreughe

## Summary by CodeRabbit

* **New Features**
  * Modules now set DeepGEMM JIT include directories at runtime so fused MoE modules have correct JIT include paths during initialization.
* **Chores**
  * JIT compiler API and module build updated to accept and propagate externally provided include directories.
  * Minor header/build adjustments to support the new initialization flow.
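The refactor described in #2090 could look roughly like the following on the Python side — a hedged sketch only, where the directory layout (`data/include`) and the binding name (`set_include_dirs`) are hypothetical:

```python
from pathlib import Path

def deepgemm_include_dirs(package_dir: Path) -> list[str]:
    """Assemble the -I directories the DeepGEMM JIT needs.

    `package_dir` is the installed package root; the subdirectory
    layout used here is illustrative, not the actual FlashInfer layout.
    """
    include_root = package_dir / "data" / "include"
    return [str(include_root), str(include_root / "deep_gemm")]

# The computed list would then be handed to the C++ side once at module
# init, e.g. via a binding such as `module.set_include_dirs(dirs)`
# (name hypothetical).
dirs = deepgemm_include_dirs(Path("/opt/flashinfer"))
print(dirs)
```

Computing the paths in Python and passing them down removes the subprocess call entirely, which is the design the reviewers converged on.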
Merged commit (#1969): feat: enable deepgemm jit for fp8 block-scale on SM90

## 📌 Description

Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently disabled; NVCC is used by default.

## Summary by CodeRabbit

* **Refactor**
  * JIT include directory discovery now uses the flashinfer-python package instead of the previous package.
  * Updated resolved include path to the flashinfer data location.
  * Runtime compilation now consistently uses NVCC; the prior environment-variable toggle was removed.
  * Updated warning text when the expected package installation cannot be found.

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
📌 Description
Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently disabled; NVCC is used by default.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Summary by CodeRabbit