
[Enhancement] Add support for vectorized loading in gemm_v1 on ROCm.#1331

Closed
Gongen-Ali wants to merge 1 commit into tile-ai:main from Gongen-Ali:main

Conversation

@Gongen-Ali
Collaborator

@Gongen-Ali Gongen-Ali commented Nov 25, 2025

Using vectorized loads to transfer data from shared memory to registers in gemm_v1 achieves significant performance gains. Moreover, when computing FP8 GEMM with k_pack=2, gemm_v1 performs better than gemm_v2.
MI308X:
M=N=K=16384, dtype=float8_e4m3fnuz
v1:
Best latency (s): 33.353660583496094 Best TFlops: 263.722 Best config: {'block_M': 128, 'block_N': 128, 'block_K': 64, 'num_stages': 1, 'thread_num': 256, 'policy': <GemmWarpPolicy.Square: 0>, 'enable_rasteration': False}
v2:
Best latency (s): 36.72024154663086 Best TFlops: 239.543 Best config: {'block_M': 128, 'block_N': 128, 'block_K': 128, 'num_stages': 1, 'thread_num': 256, 'policy': <GemmWarpPolicy.Square: 0>, 'enable_rasteration': True}
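
(As a sanity check on these figures: a 16384³ GEMM performs 2 × 16384³ ≈ 8.796 × 10¹² FLOPs, and 8.796 × 10¹² FLOPs ÷ 33.35 ms ≈ 263.7 TFLOPS, matching the v1 result above, so the "latency (s)" values reported by the benchmark script appear to actually be milliseconds.)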

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Nov 25, 2025

Walkthrough

Three HIP template files are modified to enhance vectorization and FP8 support: a new float32x2 vector type alias is introduced in common headers, GEMM kernels are refactored to use vectorized loads for non-transposed data paths, and a new make_fp8_e4_16_t device function is added to construct 16-wide FP8 vectors.

Changes

  • Vector Type Aliases — src/tl_templates/hip/common.h: Adds a float32x2 typedef as a 2-element float SIMD vector alias.
  • GEMM Vectorized Loading — src/tl_templates/hip/gemm.h: Refactors the A and B data loading paths to use vectorized loads (float32x2 or float32x4) when not transposed; introduces kPack conditional branches and increments fetch loops by the vector size; preserves the existing transpose-aware logic.
  • FP8 Packing Helpers — src/tl_templates/hip/hip_fp8.h: Adds a make_fp8_e4_16_t device function to construct 16-wide FP8 vectors from eight x and eight y values; follows the existing fp8_e4_4_t and fp8_e4_8_t packing patterns.
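
For reference, a minimal sketch of the kind of alias and load this enables, assuming the HIP/Clang ext_vector_type pattern used by the wider aliases in common.h (the exact declaration in the PR may differ, and load2 is a hypothetical helper name):

```cpp
// 2-element float vector, following the ext_vector_type pattern
// used for the existing wider aliases such as float32x4.
typedef float float32x2 __attribute__((ext_vector_type(2)));

// A vectorized 8-byte load: reinterpret a float pointer as float32x2*
// and read two consecutive floats in a single transaction.
__device__ inline float32x2 load2(const float *src) {
  return *reinterpret_cast<const float32x2 *>(src);
}
```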

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • src/tl_templates/hip/gemm.h requires careful attention to verify correct vector-size indexing and that both transposed/non-transposed paths maintain correctness across kPack variants
  • src/tl_templates/hip/hip_fp8.h new function logic should be validated against the existing fp8_e4_8_t construction pattern for consistency
  • src/tl_templates/hip/common.h is straightforward but verify the vector size (2 × sizeof(float)) matches intended usage in other files

Suggested reviewers

  • LeiWang1999

Poem

🐰 A float32x2 hops into view,
Vectorized loads make GEMM run true,
Eight x's, eight y's, now sixteen strong,
FP8 packing sings its song!
Efficiency bounds—hop, hop, hooray! 🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title 'Add support for vectorized loading in gemm_v1 on ROCm' directly and concisely describes the main change: introducing vectorized data loading for improved GEMM performance on ROCm.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)

131-165: Consider simplifying to match the CUDA implementation pattern.

The current implementation manually reinterprets and packs all 16 FP8 values. The CUDA version (shown in the relevant snippets) achieves the same result more concisely by delegating to the existing make_fp8_e4_8_t function:

```cpp
fp8_e4_16_t result;
result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
```

This refactor would improve maintainability and reduce code duplication.
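
A self-contained sketch of what the delegating version might look like; the layout of fp8_e4_16_t (two 8-wide halves named x and y) and the element and parameter types are assumptions inferred from the snippet above, not the repository's actual definitions:

```cpp
// Assumed layout: a 16-wide FP8 vector packs two 8-wide halves.
struct fp8_e4_16_t {
  fp8_e4_8_t x;
  fp8_e4_8_t y;
};

__device__ inline fp8_e4_16_t make_fp8_e4_16_t(
    fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
    fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
    fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
    fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
  fp8_e4_16_t result;
  // Delegate to the existing 8-wide packer rather than re-packing by hand.
  result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
  result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
  return result;
}
```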

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b020685 and 010fd94.

📒 Files selected for processing (3)
  • src/tl_templates/hip/common.h (1 hunks)
  • src/tl_templates/hip/gemm.h (2 hunks)
  • src/tl_templates/hip/hip_fp8.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tl_templates/hip/hip_fp8.h (2)
src/tl_templates/cuda/cuda_fp8.h (1)
  • make_fp8_e4_16_t (113-122)
src/tl_templates/cuda/gemm_sp_sm80.h (1)
  • signed (128-130)
src/tl_templates/hip/gemm.h (1)
examples/gemm/example_gemm_intrinsics.py (1)
  • make_swizzle_layout (11-23)
🔇 Additional comments (4)
src/tl_templates/hip/common.h (1)

89-89: LGTM! Proper vector type addition.

The new float32x2 typedef follows the established pattern and enables the vectorized loads introduced in gemm.h.

src/tl_templates/hip/gemm.h (3)

185-205: Vectorization logic for A matrix is correct.

The conditional vectorization based on TransposeA is appropriate:

  • Non-transposed A (row-major [M,K]): accesses consecutive K elements → vectorizable
  • Transposed A ([K,M]): accesses strided elements → scalar loads retained

The implementation correctly uses float32x2 for kPack==1 and float32x4 for kPack>1.
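
A minimal sketch of this conditional load pattern; the function name, float element type, and index variables are hypothetical simplifications of the actual loop in gemm.h:

```cpp
// Hypothetical sketch of the non-transposed A load path.
template <bool TransposeA, int kPack>
__device__ void load_a_fragment(const float *A_shared, float *A_local,
                                int reg_idx, int smem_offset) {
  if constexpr (!TransposeA) {
    if constexpr (kPack == 1) {
      // Row-major A: consecutive K elements, one 8-byte vector load.
      *reinterpret_cast<float32x2 *>(&A_local[reg_idx]) =
          *reinterpret_cast<const float32x2 *>(&A_shared[smem_offset]);
    } else {
      // kPack > 1: a wider 16-byte vector load.
      *reinterpret_cast<float32x4 *>(&A_local[reg_idx]) =
          *reinterpret_cast<const float32x4 *>(&A_shared[smem_offset]);
    }
  } else {
    // Transposed A: elements are strided, so scalar loads are kept.
    A_local[reg_idx] = A_shared[smem_offset];
  }
}
```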


211-231: Verify alignment guarantees for vectorized loads.

The vectorized loads use reinterpret_cast to float32x2* and float32x4*, which requires proper memory alignment. While the swizzled layout likely ensures alignment, explicitly verifying this would prevent potential misaligned access issues on hardware that enforces strict alignment.

Additionally, the type-punning via float32x* casts is a common GPU pattern for vectorized loads but relies on implementation-defined behavior. This works in practice for byte-copying but is worth documenting.
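
One way to make the alignment assumption explicit in debug builds; this helper is illustrative only and not part of the PR:

```cpp
#include <cassert>
#include <cstdint>

// Checked vectorized load: asserts the address satisfies the wide
// type's alignment before reinterpreting the pointer.
template <typename VecT, typename T>
__device__ inline VecT load_vec_checked(const T *ptr) {
  assert(reinterpret_cast<std::uintptr_t>(ptr) % alignof(VecT) == 0);
  return *reinterpret_cast<const VecT *>(ptr);
}
```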


276-296: Same alignment verification needed in body_rs variant.

The B matrix loading in body_rs() uses the same vectorized load pattern as in body(). Ensure alignment guarantees apply here as well.

@LeiWang1999
Member

Thanks @Gongen-Ali, we're good to go once the lint is fixed with ./format.sh.

It's also surprising that v2 is slower than v1; would you mind digging further and helping us fix the performance issue?

@Gongen-Ali
Collaborator Author

Gongen-Ali commented Nov 25, 2025

> Thanks @Gongen-Ali, we're good to go once the lint is fixed with ./format.sh.
>
> It's also surprising that v2 is slower than v1; would you mind digging further and helping us fix the performance issue?

👌 We have already analyzed this issue and preliminarily concluded that, when k_pack=2, the vectorized transfer from `buf_dyn_shmem` to `local` uses only fp8_e4_8_t, whereas fp8_e4_16_t could be used instead to improve transfer efficiency. We are currently testing this approach and plan to modify TVM's automatic vectorization handling accordingly.
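
To illustrate the difference with hypothetical pointers (the real copy is emitted by TVM's vectorizer, and the fp8_e4_* types come from hip_fp8.h), moving 16 FP8 values from shared memory to registers:

```cpp
// Current behavior at k_pack=2: two 8-byte transactions.
__device__ void copy16_two_loads(const fp8_e4_t *src, fp8_e4_t *dst) {
  *reinterpret_cast<fp8_e4_8_t *>(dst) =
      *reinterpret_cast<const fp8_e4_8_t *>(src);
  *reinterpret_cast<fp8_e4_8_t *>(dst + 8) =
      *reinterpret_cast<const fp8_e4_8_t *>(src + 8);
}

// Desired behavior: a single 16-byte transaction.
__device__ void copy16_one_load(const fp8_e4_t *src, fp8_e4_t *dst) {
  *reinterpret_cast<fp8_e4_16_t *>(dst) =
      *reinterpret_cast<const fp8_e4_16_t *>(src);
}
```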

@Gongen-Ali
Collaborator Author

Gongen-Ali commented Nov 26, 2025

We have fixed the v2 performance issue in #1344 and are temporarily closing this PR.

@Gongen-Ali Gongen-Ali closed this Nov 26, 2025