[Doc] Optimize the quickstart guide for clarity and not just for CUDA #858
Conversation
Refactor matmul example to include ReLU activation and update batch size in benchmark script
Caution: Review failed. The pull request is closed.

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

README and example updated: inner kernel renamed to matmul_relu_kernel, a ReLU applied after GEMM, JIT wrapper changed to @tilelang.jit(target="cuda"), matmul now returns the renamed kernel, and profiling/usage paths adjusted accordingly.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant J as matmul (JIT wrapper)
    participant K as matmul_relu_kernel
    participant G as CUDA GPU
    participant P as Profiler
    participant T as PyTorch
    U->>J: call matmul(M,N,K,dtype)
    J-->>U: returns K (matmul_relu_kernel)
    U->>K: launch(A,B,C)
    K->>G: execute kernel
    rect rgba(220,235,255,0.35)
    note over G: Tile load & GEMM
    G-->>G: accumulate into C_local
    G-->>G: apply ReLU to C_local
    G-->>G: store C_local -> C
    end
    G-->>U: C populated
    U->>T: compute ref = relu(A @ B)
    T-->>U: ref tensor
    U-->>U: compare C vs ref
    U->>K: K.get_profiler(tensor_supply=...)
    K-->>P: profiler
    U->>P: run profiling
    P-->>U: latency metrics
```
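For readers skimming the diagram, a rough host-side sketch of the same flow follows (minus the profiling step, which is covered further down). The import path, problem sizes, and tolerances are illustrative assumptions; only the kernel name and the `relu(A @ B)` reference come from this PR.

```python
import torch

# Hypothetical import path; in the PR the wrapper lives in examples/quickstart.py.
from examples.quickstart import matmul

# Illustrative problem and tile sizes (not part of the diff).
M = N = K = 1024
block_M, block_N, block_K = 128, 128, 32

# 1. The JIT wrapper returns the renamed kernel object.
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 2. Launch on device tensors (CUDA here; the review below suggests a portable variant).
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
matmul_relu_kernel(a, b, c)

# 3. Verify against PyTorch: ref = relu(A @ B).
ref = torch.relu(a @ b)
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)
```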
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)

📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refines the quickstart guide and its accompanying example code to improve user comprehension and highlight the framework's broader applicability. By removing CUDA-specific optimization details and introducing a more streamlined example built around the @tilelang.jit decorator, the guide becomes clearer and less tied to a single backend.
Code Review
This pull request does a great job of optimizing the quickstart guide for clarity and making it more platform-agnostic. The changes in README.md and examples/quickstart.py simplify the code by removing CUDA-specific details and showcasing a more streamlined API with the @tilelang.jit decorator. The addition of the ReLU operation makes the example more practical. I've added a couple of minor suggestions for the README.md file to fix some inconsistencies in the comments and code examples to ensure they are correct and easy for users to follow.
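To ground the review comments that follow, here is a minimal sketch of the decorator-plus-inner-kernel pattern being discussed. Only the @tilelang.jit wrapper, the matmul_relu_kernel name, and the T.Parallel ReLU loop come directly from this PR; the tiled GEMM body (shared-memory copies, pipelining, block sizes) follows the usual TileLang quickstart shape and is an assumption here, and exact primitive spellings may differ between versions.

```python
import tilelang
import tilelang.language as T

@tilelang.jit  # target left unspecified so it can be inferred at compile time
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Fused ReLU on the accumulator before the final store
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            T.copy(C_local, C[by * block_M, bx * block_N])

    # Return the named prim_func so callers get a clearly named kernel object
    return matmul_relu_kernel
```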
Review comment context:

```python
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 3. Test the kernel in Python with PyTorch data
```
Review comment context:

```python
# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# cuda_source = jit_kernel.get_kernel_source()
```
The variable jit_kernel is no longer used in this example; it has been replaced by matmul_relu_kernel. This commented-out line should be updated to use the correct variable so that it works correctly if a user decides to uncomment it.
Suggested change:

```diff
-# cuda_source = jit_kernel.get_kernel_source()
+# cuda_source = matmul_relu_kernel.get_kernel_source()
```
Actionable comments posted: 0
🧹 Nitpick comments (9)
examples/quickstart.py (3)
40-42: Use a float literal for ReLU clamp to avoid implicit int→float cast. Replace `0` with `0.0` for type clarity across dtypes.

```diff
-            for i, j in T.Parallel(block_M, block_N):
-                C_local[i, j] = T.max(C_local[i, j], 0)
+            for i, j in T.Parallel(block_M, block_N):
+                C_local[i, j] = T.max(C_local[i, j], 0.0)
```
63-67: Generalize device selection so the example isn't CUDA-only. Pick the device dynamically; keeps the quickstart consistent with the PR's goal.

```diff
-a = torch.randn(M, K, device="cuda", dtype=torch.float16)
-b = torch.randn(K, N, device="cuda", dtype=torch.float16)
-c = torch.empty(M, N, device="cuda", dtype=torch.float16)
+dev = "cuda" if torch.cuda.is_available() else "cpu"
+a = torch.randn(M, K, device=dev, dtype=torch.float16)
+b = torch.randn(K, N, device=dev, dtype=torch.float16)
+c = torch.empty(M, N, device=dev, dtype=torch.float16)
```

Note: If CPU support is intended, torch.float16 matmul may not be available on CPU; consider switching both the kernel `dtype` and the torch tensors to float32 in that case.
68-70: Fix comment: this call executes the kernel, not the profiler.

```diff
-# Run the kernel through the Profiler
+# Execute the kernel
```

README.md (6)
127-131: Keep auto-target, but show explicit target selection as a tip. Small copy tweak improves clarity for users targeting HIP/CPU explicitly.

```diff
-# @tilelang.jit(target="cuda")
-# target currently can be "cuda" or "hip" or "cpu".
-# if not specified, it will be inferred from the input tensors during compile time
+## Tip: explicit target selection if needed
+# @tilelang.jit(target="cuda")  # or target="hip" or "cpu"
+# If not specified, target is inferred from input tensors at compile time.
```
162-166: Use `0.0` in ReLU to avoid int→float cast and improve readability.

```diff
-    for i, j in T.Parallel(block_M, block_N):
-        C_local[i, j] = T.max(C_local[i, j], 0)
+    for i, j in T.Parallel(block_M, block_N):
+        C_local[i, j] = T.max(C_local[i, j], 0.0)
```
187-190: Make device selection portable; avoid CUDA-only tensor creation. Keeps the quickstart aligned with "not just for CUDA."

```diff
-a = torch.randn(M, K, device="cuda", dtype=torch.float16)
-b = torch.randn(K, N, device="cuda", dtype=torch.float16)
-c = torch.empty(M, N, device="cuda", dtype=torch.float16)
+dev = "cuda" if torch.cuda.is_available() else "cpu"
+a = torch.randn(M, K, device=dev, dtype=torch.float16)
+b = torch.randn(K, N, device=dev, dtype=torch.float16)
+c = torch.empty(M, N, device=dev, dtype=torch.float16)
```

If CPU is in scope, you may need to use float32 on CPU and pass `dtype="float32"` into `matmul(...)`. Want a short snippet for that? A possible shape is sketched below.
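One way that snippet could look, assuming `matmul` accepts a `dtype` keyword as the surrounding example suggests; the sizes and variable names here are illustrative, not taken from the diff.

```python
import torch

# Illustrative sizes; matmul is the quickstart JIT wrapper.
M = N = K = 1024
block_M, block_N, block_K = 128, 128, 32

dev = "cuda" if torch.cuda.is_available() else "cpu"
kernel_dtype = "float16" if dev == "cuda" else "float32"   # assumed dtype kwarg
torch_dtype = torch.float16 if dev == "cuda" else torch.float32

matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=kernel_dtype)

a = torch.randn(M, K, device=dev, dtype=torch_dtype)
b = torch.randn(K, N, device=dev, dtype=torch_dtype)
c = torch.empty(M, N, device=dev, dtype=torch_dtype)
```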
192-193: Comment nit: this line runs the kernel, not the profiler.

```diff
-# Run the kernel through the Profiler
+# Execute the kernel
```
206-208: Minor copyedit for section title.

```diff
-# 5.Profile latency with kernel
+# 5. Profile kernel latency
```
203-205: Update README snippet: replace stale `jit_kernel` variable and CUDA-specific wording. The commented snippet at README.md:203-204 still uses `jit_kernel`/"CUDA"; update it to use the returned kernel variable and neutral wording.

```diff
-# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
-# print("Generated CUDA kernel:\n", cuda_source)
+# 4. Retrieve and inspect the generated device-specific source (optional)
+# src = matmul_relu_kernel.get_kernel_source()
+# print("Generated kernel source:\n", src)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
README.md (2 hunks)
examples/quickstart.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/quickstart.py (3)
tilelang/jit/__init__.py (1)
  jit (237-310)
tilelang/jit/kernel.py (1)
  get_profiler (360-376)
tilelang/utils/tensor.py (1)
  TensorSupplyType (11-18)
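Given the get_profiler and TensorSupplyType entries above, the profiling step presumably looks roughly like the sketch below; the Normal supply value and the do_bench() call are assumptions based on typical TileLang profiler usage, not taken from this diff.

```python
from tilelang.utils.tensor import TensorSupplyType

# matmul_relu_kernel is the compiled kernel returned by the JIT wrapper.
profiler = matmul_relu_kernel.get_profiler(tensor_supply=TensorSupplyType.Normal)

# Benchmark the kernel and report latency (milliseconds).
latency = profiler.do_bench()
print(f"Kernel latency: {latency:.3f} ms")
```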
🔇 Additional comments (4)
examples/quickstart.py (3)
4-7: Auto-target JIT aligns with "not just CUDA." Good change to let the backend be inferred from inputs; keeps the example portable.

11-15: Kernel entry rename improves clarity. Returning a clearly named prim_func is nicer than a generic `main`.

47-47: Returning the prim_func object is the right UX. Enables direct `.get_profiler(...)` and invocation without extra wrapping.

README.md (1)
134-138: Kernel rename/readability LGTM.
[Doc] Optimize the quickstart guide for clarity and not just for CUDA (tile-ai#858)
* Refactor matmul example to include ReLU activation and update batch size in benchmark script
* lint fix
as title.
Summary by CodeRabbit
New Features
Refactor
Documentation