[Refactor] Use cuda capability from torch to be more generic #1557
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

get_target_arch now accepts either a compute-version string or a (major, minor) tuple; compute-version parsing for 3-character arch strings is simplified and trailing "a"/"f" suffixes are stripped. determine_target detects CUDA at runtime via torch and can return a structured Target-like object with kind "cuda" and an arch entry derived from the detected compute capability.
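The detection path described above can be sketched roughly as follows. This is a hedged illustration rather than the merged code: detect_cuda_target is a hypothetical helper name (the PR actually touches determine_target in tilelang/utils/target.py), and it assumes tilelang.contrib.nvcc.get_target_arch and TVM's Target are importable.

```python
import torch
from tvm.target import Target

from tilelang.contrib import nvcc


def detect_cuda_target():
    """Return a structured CUDA Target when a device is visible, else the plain "cuda" string."""
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)  # e.g. (9, 0) on a Hopper GPU
        arch = nvcc.get_target_arch(cap)           # e.g. "90a"
        return Target({"kind": "cuda", "arch": f"sm_{arch}"})
    return "cuda"
```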
Sequence Diagram(s): omitted; changes are localized parsing and runtime detection without new multi-component interaction flows.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ failed checks (1 warning), ✅ passed checks (2 passed)
📜 Recent review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-24T17:20:32.819Z
Applied to files:
tilelang/contrib/nvcc.py
🪛 Ruff (0.14.10)
tilelang/contrib/nvcc.py
418-418: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
🔇 Additional comments (3)
d2b1020 to d920eef (Compare)
Actionable comments posted: 2
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/contrib/nvcc.py
tilelang/utils/target.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-24T17:20:32.819Z
Learnt from: clouds56
Repo: tile-ai/tilelang PR: 1527
File: tilelang/env.py:0-0
Timestamp: 2025-12-24T17:20:32.819Z
Learning: The nvidia-cuda-nvcc PyPI package installs to `nvidia/cu13/bin/` (for CUDA 13), `nvidia/cu12/bin/` (for CUDA 12), and `nvidia/cu11/bin/` (for CUDA 11) in the site-packages directory, not to `nvidia/cuda_nvcc/bin/`. These paths should be used when detecting CUDA installations from PyPI packages in tilelang/env.py.
Applied to files:
tilelang/contrib/nvcc.py
🧬 Code graph analysis (2)
tilelang/utils/target.py (2)
tilelang/language/ast/ir.py (1)
target (1677-1707)
tilelang/contrib/nvcc.py (1)
get_target_arch(452-460)
tilelang/contrib/nvcc.py (1)
tilelang/contrib/rocm.py (1)
parse_compute_version(179-201)
🪛 Ruff (0.14.10)
tilelang/contrib/nvcc.py
418-418: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (4)
tilelang/utils/target.py (2)
2-4: Verify the torch dependency is acceptable.

Adding torch as an unconditional import increases the package's dependency footprint. While this aligns with the PR objective of using PyTorch's CUDA capability detection, ensure this is acceptable for all use cases of tilelang. Consider whether a lazy import or optional dependency pattern would be more appropriate if torch isn't always available in deployment environments (see the sketch below).
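One way to address the lazy-import suggestion, sketched under the assumption that only the device capability is needed at detection time (the helper name _torch_cuda_capability is illustrative):

```python
def _torch_cuda_capability():
    """Return (major, minor) for device 0, or None if torch or CUDA is unavailable."""
    try:
        import torch  # deferred import keeps torch an optional dependency
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)
```

Callers could then fall back to the plain "cuda" target whenever this returns None.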
89-89: Update the docstring to document the dict[str, str] return type.

The function signature correctly includes dict[str, str] in the return type (when CUDA is auto-detected with capabilities), but the docstring at lines 98-99 only documents Union[str, Target] and omits the dict case. This creates a documentation mismatch. All callers already handle this correctly: they either wrap the result in Target() (which accepts a dict), use return_object=True (which returns a Target), or pass it to Target.canon_target(). Update the docstring to reflect the actual return type.

Likely an incorrect or invalid review comment.
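A sketch of the requested docstring update; the parameter list is abbreviated and assumed here, and only the Returns section is the point:

```python
def determine_target(target="auto", return_object=False):
    """Determine the compilation target.

    Returns:
        Union[str, Target, dict[str, str]]: the target name (e.g. "cuda"), a canonical
        Target when return_object=True, or a dict such as
        {"kind": "cuda", "arch": "sm_90a"} when CUDA is auto-detected with device
        capabilities.
    """
    ...
```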
tilelang/contrib/nvcc.py (2)
452-460: LGTM! Enhanced flexibility for compute-version input.

The function now correctly handles both string and tuple inputs, which enables seamless integration with torch.cuda.get_device_capability(), which returns tuples. The logic correctly:
- Parses the string format via the existing parse_compute_version()
- Accepts tuples directly from PyTorch
- Maintains the 'a' suffix for architectures >= 9.0
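A minimal, self-contained re-sketch of the behaviour listed above (the local _parse_compute_version stand-in replaces the library's parse_compute_version; the real implementation lives at tilelang/contrib/nvcc.py:452-460):

```python
def _parse_compute_version(version: str) -> tuple[int, int]:
    # Stand-in for the library's parse_compute_version: "8.6" -> (8, 6)
    major, minor = version.split(".")
    return int(major), int(minor)


def get_target_arch(compute_version):
    """Map "8.6" or (8, 6) to an arch string like "86", keeping the "a" suffix from sm_90 on."""
    if isinstance(compute_version, str):
        major, minor = _parse_compute_version(compute_version)
    else:
        major, minor = compute_version  # tuple straight from torch.cuda.get_device_capability()
    arch = f"{major}{minor}"
    return f"{arch}a" if (major, minor) >= (9, 0) else arch
```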
408-418: The architecture parsing logic correctly handles both current and future NVIDIA GPU architectures. NVIDIA's upcoming Blackwell architecture uses 3-digit compute capabilities (sm_100 family), confirming that the code's handling of 3-character architectures is sound and forward-looking, not speculative. The rstrip("af") removal of suffixes followed by length-based parsing (2-char for current sm_70/80/90 architectures, 3-char for future sm_100+) is appropriate and requires no changes.
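For reference, the parsing shape described in that comment could be sketched like this (a hedged illustration; the helper name _arch_to_compute_version is hypothetical, and the real logic sits around tilelang/contrib/nvcc.py:408-418):

```python
def _arch_to_compute_version(arch: str) -> str:
    """Sketch: "90a" -> "9.0", "120a" -> "12.0"."""
    arch = arch.rstrip("af")  # drop feature suffixes, e.g. "90a" -> "90"
    if len(arch) == 2:        # sm_70 / sm_80 / sm_90 families
        return f"{arch[0]}.{arch[1]}"
    if len(arch) == 3:        # sm_100 / sm_120 families
        return f"{arch[:2]}.{arch[2]}"
    raise ValueError(f"Unsupported arch: {arch}")
```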
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/utils/target.py (1)
118-121: Add error handling for torch CUDA device access.

As noted in a previous review, torch.cuda.get_device_capability(0) can raise exceptions if device 0 is not accessible or PyTorch CUDA initialization fails, even when torch.cuda.is_available() returns True. This should be wrapped in a try/except to gracefully fall back to the plain "cuda" target.

🔎 Proposed fix with error handling
```diff
     if is_cuda_available:
-        if torch.cuda.is_available() and (cap := torch.cuda.get_device_capability(0)):
-            return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
-        else:
-            return_var = "cuda"
+        try:
+            import torch
+            if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+                cap = torch.cuda.get_device_capability(0)
+                if cap:
+                    return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
+                else:
+                    return_var = "cuda"
+            else:
+                return_var = "cuda"
+        except Exception:
+            # Fall back to plain "cuda" if torch is unavailable or CUDA detection fails
+            return_var = "cuda"
```

tilelang/contrib/nvcc.py (1)
418-418: Fix the error message format string.

The error message is missing the f prefix and the variable in the format placeholder. This was noted in a previous review but appears unaddressed.

🔎 Proposed fix

```diff
 else:
-    raise ValueError("Unsupported arch: {}")
+    raise ValueError(f"Unsupported arch: {arch}")
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/contrib/nvcc.py
tilelang/utils/target.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/utils/target.py (1)
tilelang/contrib/nvcc.py (1)
get_target_arch(452-460)
tilelang/contrib/nvcc.py (1)
tilelang/contrib/rocm.py (1)
parse_compute_version(179-201)
🪛 Ruff (0.14.10)
tilelang/contrib/nvcc.py
418-418: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
tilelang/contrib/nvcc.py (2)
452-460: LGTM! Signature change to accept a tuple is well-integrated.

The updated get_target_arch correctly handles both string compute versions (e.g., "8.6") and tuples from torch.cuda.get_device_capability() (e.g., (8, 6)). The math and suffix logic remain correct:
- (8, 6) → "86"
- (9, 0) → "90a"
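If those mappings hold, a quick sanity check of both call styles could look like this (assumes an environment where tilelang.contrib.nvcc imports cleanly):

```python
from tilelang.contrib import nvcc

# Expected values are the ones quoted in the review above.
assert nvcc.get_target_arch("8.6") == "86"
assert nvcc.get_target_arch((8, 6)) == "86"
assert nvcc.get_target_arch((9, 0)) == "90a"
```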
408-416: Arch parsing simplification looks correct.

The rstrip("af") properly handles architecture suffixes (e.g., "90a" → "90", "120a" → "120"), and the 3-character case correctly parses newer architectures like sm_120 to "12.0".
This should reduce the confusion caused by TVM's default sm_50 when the arch fails to be detected.

Summary by CodeRabbit
Refactor
New Behavior