[Spark bug] Fix arch 12.1 -> "sm120a" flag for Spark, CUDA 12.9 #2839
aleozlx merged 2 commits into flashinfer-ai:main from
Conversation
Summary of Changes (Gemini Code Assist)

This pull request resolves a critical compilation issue affecting Spark GPUs (SM 12.1) when used with CUDA 12.9. The change ensures that the correct architecture flag ("120f") is applied during compilation, preventing runtime failures caused by incompatible kernel images. This update improves the robustness of the compilation process for specific GPU architectures and CUDA versions.
📝 Walkthrough

Adjusted SM 12.x CUDA architecture normalization to require CUDA >= 12.9 by checking the version directly and raising a RuntimeError when unmet; removed the previous exception-based fallback path and simplified the related import/usage.
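The normalization described in the walkthrough can be sketched as follows. This is an illustrative reconstruction, not FlashInfer's actual API; the function and parameter names are hypothetical, and the real logic lives in `flashinfer/compilation_context.py`:

```python
# Hypothetical sketch of the SM 12.x normalization described above.
# Names are illustrative; see flashinfer/compilation_context.py for the
# real implementation.

def normalize_sm12x(major: int, cuda_version: tuple) -> tuple:
    """Return the (major, suffix) pair used to build the arch string.

    On CUDA >= 12.9, SM 12.0 and SM 12.1 (Spark) both compile with the
    family-specific "0f" suffix (i.e. "120f"); older toolkits cannot
    target it, so the code fails fast instead of emitting a binary that
    dies at runtime with "no kernel image is available".
    """
    if cuda_version >= (12, 9):
        return (major, "0f")
    raise RuntimeError("SM 12.x requires CUDA >= 12.9")

major, suffix = normalize_sm12x(12, (12, 9))
print(f"{major}{suffix}")  # → 120f
```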
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request correctly addresses a bug related to CUDA architecture flags for Spark on CUDA 12.9 by lowering the required version from 13.0. However, the change introduces a critical circular import by moving an import statement to the top level of flashinfer/compilation_context.py. This will cause the application to fail on startup with an ImportError. I've added comments with suggestions to resolve this by reverting to a local import within the function where it's used, which is a common pattern for breaking such dependency cycles.
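For readers unfamiliar with the pattern the review suggests: deferring an import into the function body breaks the cycle because the import is resolved at call time, after both modules have finished initializing. A minimal self-contained sketch of the idea — the `helper_mod` module here is simulated for the example, not one of FlashInfer's actual modules:

```python
import sys
import types

# Simulate a second module so the example is self-contained; in a real
# code base this would be the other module participating in the cycle.
helper = types.ModuleType("helper_mod")
helper.is_cuda_version_at_least = lambda v: True  # stub for illustration
sys.modules["helper_mod"] = helper

def sm12x_suffix(major: int) -> str:
    # Function-local import: resolved only when the function is called,
    # not while this module itself is being imported, so a circular
    # top-level dependency never triggers an ImportError at startup.
    from helper_mod import is_cuda_version_at_least
    return "0f" if is_cuda_version_at_least("12.9") else "a"

print(sm12x_suffix(12))  # → 0f
```

The trade-off is a tiny per-call import lookup (cached in `sys.modules` after the first call), which is usually negligible next to the clarity of keeping module initialization acyclic.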
🧹 Nitpick comments (1)
flashinfer/compilation_context.py (1)
48-51: Correct fix for the CUDA 12.9/Spark compilation issue. The threshold change from "13.0" to "12.9" correctly ensures Spark (SM 12.1) and SM 12.0 get the "0f" suffix on CUDA 12.9. The fail-fast RuntimeError is better than producing non-functional "120a" binaries that fail at runtime with "no kernel image is available". Consider making the error message slightly more actionable:
💡 Optional: More descriptive error message
```diff
  if is_cuda_version_at_least("12.9"):
      return (major, "0f")
  else:
-     raise RuntimeError("SM 12.x requires CUDA >= 12.9")
+     raise RuntimeError(
+         "SM 12.x (Spark/Thor) requires CUDA >= 12.9. "
+         "Please upgrade your CUDA toolkit."
+     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/compilation_context.py` around lines 48 - 51, The branch handling SM 12.x should keep the CUDA threshold at "12.9": when is_cuda_version_at_least("12.9") return (major, "0f"), otherwise raise a RuntimeError; update the RuntimeError in the same block (where (major, "0f") is returned and is_cuda_version_at_least is called) to a clearer message such as "CUDA >= 12.9 is required to compile for SM 12.x (avoids producing non-functional SM 12.0/12.1 binaries)"; locate and modify this logic in compilation_context.py where is_cuda_version_at_least and the (major, "0f") tuple are used.
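As a point of reference, a string-based "at least" version check of the kind referenced here can be implemented by comparing numeric tuples, so that "12.10" correctly sorts above "12.9". This is a generic sketch, not FlashInfer's implementation of `is_cuda_version_at_least`:

```python
def version_at_least(current: str, required: str) -> bool:
    """True if dotted version string `current` >= `required`.

    Components are compared numerically, not lexicographically, so
    "12.10" ranks above "12.9" (a plain string compare would get this
    wrong).
    """
    parse = lambda v: tuple(int(part) for part in v.split("."))
    return parse(current) >= parse(required)

print(version_at_least("13.0", "12.9"))   # → True
print(version_at_least("12.8", "12.9"))   # → False
print(version_at_least("12.10", "12.9"))  # → True
```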
ℹ️ Review info: defaults configuration, CHILL profile, Pro plan, Run ID 1766d240-187b-47d9-827b-e450d27b0a41
📒 Files selected for processing (1)
flashinfer/compilation_context.py
/bot run
[FAILED] Pipeline #46625067: 14/20 passed
@kahyunnam if you'd like to correct the PR desc about Thor |
📌 Description
Bug found in the nightly [Spark, 12.9] matrix (https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/285092631), where Spark compiles to "120a" (see the "/tmp/.cache/flashinfer/0.6.6/120a/" path in the log below).
The root cause was #2725, where we added logic to compile both Spark and RTX Pro 6000 to "120f", but only on the condition that the CUDA version is 13 or higher. Lower versions (e.g., 12.9) defaulted to the "a" suffix, producing "120a".
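To illustrate the regression, the arch string chosen for Spark under the old "13.0" threshold versus the fixed "12.9" one can be sketched as follows (simplified and with hypothetical names; the actual logic is in `flashinfer/compilation_context.py`):

```python
def sm121_arch(cuda: tuple, threshold: tuple) -> str:
    """Arch string chosen for Spark (SM 12.1) given a CUDA threshold."""
    return "120f" if cuda >= threshold else "120a"

# Old behavior from #2725: threshold 13.0, so CUDA 12.9 fell through to
# the "a" suffix and Spark was built as the non-functional "120a"
# (hence the .../0.6.6/120a/ cache path in the failing job's log).
print(sm121_arch((12, 9), (13, 0)))  # → 120a

# Fixed behavior: threshold lowered to 12.9, yielding "120f".
print(sm121_arch((12, 9), (12, 9)))  # → 120f
```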
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests are passing (`unittest`, etc.).

Reviewer Notes