[Enhancement] Add support for vectorized loading in gemm_v1 on ROCm. #1331
Gongen-Ali wants to merge 1 commit into tile-ai:main from …
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's lint/format checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Three HIP template files are modified to enhance vectorization and FP8 support: a new `float32x2` typedef in common.h, vectorized shared-memory-to-register loads in gemm.h, and a `make_fp8_e4_16_t` constructor in hip_fp8.h.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) | ✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)
131-165: Consider simplifying to match the CUDA implementation pattern.

The current implementation manually reinterprets and packs all 16 FP8 values. The CUDA version (shown in the relevant snippets) achieves the same result more concisely by delegating to the existing `make_fp8_e4_8_t` function:

```cpp
fp8_e4_16_t result;
result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
```

This refactor would improve maintainability and reduce code duplication.
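For illustration, a hedged sketch of what the simplified HIP helper could look like as a whole; the parameter list and qualifiers are assumptions based on the CUDA counterpart, not the PR's actual code:

```cpp
// Hypothetical sketch, mirroring the CUDA make_fp8_e4_16_t (lines 113-122):
// build the 16-wide value from two 8-wide halves instead of packing all
// 16 lanes by hand.
__device__ inline fp8_e4_16_t make_fp8_e4_16_t(
    fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
    fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
    fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
    fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
  fp8_e4_16_t result;
  result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
  result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
  return result;
}
```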
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/tl_templates/hip/common.h (1 hunks)
- src/tl_templates/hip/gemm.h (2 hunks)
- src/tl_templates/hip/hip_fp8.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tl_templates/hip/hip_fp8.h (2)

- src/tl_templates/cuda/cuda_fp8.h (1)
  - make_fp8_e4_16_t (113-122)
- src/tl_templates/cuda/gemm_sp_sm80.h (1)
  - signed (128-130)

src/tl_templates/hip/gemm.h (1)

- examples/gemm/example_gemm_intrinsics.py (1)
  - make_swizzle_layout (11-23)
🔇 Additional comments (4)
src/tl_templates/hip/common.h (1)
89-89: LGTM! Proper vector type addition.

The new `float32x2` typedef follows the established pattern and enables the vectorized loads introduced in gemm.h.
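For reference, a plausible shape for such a typedef; this is an assumption, and the actual definition in common.h may differ:

```cpp
// Hypothetical sketch: a 2-wide float vector via Clang's ext_vector_type,
// alongside the presumed existing 4-wide type it mirrors.
typedef float float32x2 __attribute__((ext_vector_type(2)));
typedef float float32x4 __attribute__((ext_vector_type(4)));
```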
src/tl_templates/hip/gemm.h (3)

185-205: Vectorization logic for A matrix is correct.

The conditional vectorization based on `TransposeA` is appropriate:

- Non-transposed A (row-major [M,K]): accesses consecutive K elements → vectorizable
- Transposed A ([K,M]): accesses strided elements → scalar loads retained

The implementation correctly uses `float32x2` for `kPack == 1` and `float32x4` for `kPack > 1`, as sketched below.
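To make the described pattern concrete, here is a minimal, hypothetical sketch of a kPack-dependent vectorized shared-memory load; all names (`load_a_fragment`, `smem_a`, `regs_a`, `lda`, and the indexing) are illustrative, not the PR's actual code:

```cpp
// Illustrative sketch of the reviewed pattern, not the PR's actual code.
// Non-transposed A keeps consecutive K elements contiguous in shared
// memory, so one vector load replaces several scalar loads.
template <bool TransposeA, int kPack>
__device__ void load_a_fragment(const float *smem_a, float *regs_a,
                                int row, int col, int lda) {
  if constexpr (!TransposeA) {
    if constexpr (kPack == 1) {
      // Two contiguous floats -> a single float32x2 load.
      *reinterpret_cast<float32x2 *>(regs_a) =
          *reinterpret_cast<const float32x2 *>(&smem_a[row * lda + col]);
    } else {
      // kPack > 1: four contiguous floats -> a single float32x4 load.
      *reinterpret_cast<float32x4 *>(regs_a) =
          *reinterpret_cast<const float32x4 *>(&smem_a[row * lda + col]);
    }
  } else {
    // Transposed A ([K, M]): elements are strided, keep scalar loads.
    const int n = (kPack == 1) ? 2 : 4;
    for (int i = 0; i < n; ++i)
      regs_a[i] = smem_a[(col + i) * lda + row];
  }
}
```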
211-231: Verify alignment guarantees for vectorized loads.

The vectorized loads use `reinterpret_cast` to `float32x2 *` and `float32x4 *`, which requires proper memory alignment. While the swizzled layout likely ensures alignment, explicitly verifying it would prevent misaligned-access issues on hardware that enforces strict alignment; one possible shape for such a check is sketched below.

Additionally, the type-punning via `float32x*` casts is a common GPU pattern for vectorized loads but relies on implementation-defined behavior. It works in practice for byte-copying but is worth documenting.
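One way to make the alignment assumption explicit is a checked load helper. This is only a sketch under assumed names (`float32x4` as defined in common.h, a hypothetical `load_float32x4_checked`), not a requirement for the PR:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: assert 16-byte alignment before the vector-width
// reinterpret_cast, so misaligned shared-memory offsets fail loudly in
// debug builds instead of producing undefined behavior.
__device__ inline float32x4 load_float32x4_checked(const float *src) {
  assert(reinterpret_cast<uintptr_t>(src) % alignof(float32x4) == 0);
  return *reinterpret_cast<const float32x4 *>(src);
}
```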
276-296: Same alignment verification needed in body_rs variant.

The B matrix loading in `body_rs()` uses the same vectorized load pattern as in `body()`. Ensure the alignment guarantees apply here as well.
Thanks @Gongen-Ali, we're good to go if we can fix the lint. It's also surprising that v2 is slower than v1; would you mind digging further and helping us fix the performance issue?
We have fixed the v2 performance issue and are temporarily closing the current PR.
Using vectorization to transfer data from shared memory to registers in gemm_v1 can achieve significant performance gains. Moreover, when computing fp8 GEMM with k_pack=2, gemm_v1 performs better than gemm_v2.
MI308X:
M=N=K=16384, dtype=float8_e4m3fnuz
v1:
Best latency (ms): 33.353660583496094
Best TFlops: 263.722
Best config: {'block_M': 128, 'block_N': 128, 'block_K': 64, 'num_stages': 1, 'thread_num': 256, 'policy': <GemmWarpPolicy.Square: 0>, 'enable_rasteration': False}

v2:
Best latency (ms): 36.72024154663086
Best TFlops: 239.543
Best config: {'block_M': 128, 'block_N': 128, 'block_K': 128, 'num_stages': 1, 'thread_num': 256, 'policy': <GemmWarpPolicy.Square: 0>, 'enable_rasteration': True}
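(Sanity check on the numbers: 2 · 16384³ ≈ 8.8 × 10¹² FLOPs per GEMM, and 8.8 × 10¹² FLOPs / 33.35 ms ≈ 263.7 TFLOPS, consistent with the reported v1 throughput.)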