Skip to content

[AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950#1743

Merged
LeiWang1999 merged 3 commits intotile-ai:mainfrom
hubertlu-tw:fix/gfx950-fp8-e4m3
Jan 29, 2026
Merged

[AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950#1743
LeiWang1999 merged 3 commits intotile-ai:mainfrom
hubertlu-tw:fix/gfx950-fp8-e4m3

Conversation

@hubertlu-tw
Copy link
Contributor

@hubertlu-tw hubertlu-tw commented Jan 27, 2026

Description

This PR fixes ROCm FP8 handling across gfx942/gfx950 by selecting the correct
FP8 variants at runtime and making MFMA/codegen recognize the FP8 dtypes used
by ROCm. It also consolidates FP8 selection into shared helpers so examples and
tests stay consistent across devices.

Key changes

  • Add determine_fp8_type() / determine_torch_fp8_type() helpers.- Route ROCm E5M2 through BF8 MFMA intrinsics and add missing MFMA dtype mappings.
  • Fix FP8 E4M3/E5M2 conversions and vector wrappers in HIP templates for gfx950.
  • Update FP8 examples to use shared selection logic and ROCm-friendly paths.
  • Make ROCm FP8 tilelibrary tests select per-GPU dtype instead of hardcoding FNUZ.

Tests

  • pytest -q testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
  • python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
  • python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
  • python /opt/tilelang/examples/gemm_fp8/regression_example_gemm_fp8.py

CC: @Gongen-Ali

Summary by CodeRabbit

  • New Features

    • Automatic FP8 dtype selection at runtime, including additional FP8/BF8 variants and platform-aware mappings.
    • Extended FP8 type wrappers and storage variants to improve HIP/AMD interoperability and codegen coverage.
  • Refactor

    • Examples and tooling now determine FP8 dtypes dynamically for consistent cross‑platform behavior.
  • Tests

    • GEMM tests updated to use runtime-selected FP8 dtypes for broader coverage and more accurate validation.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 27, 2026

📝 Walkthrough

Walkthrough

Adds runtime FP8 dtype selection utilities and applies them across examples, tests, and kernel generation; refactors HIP FP8 C++ wrappers; extends MFMA intrinsic generator and HIP codegen to recognize additional FP8/BF8 variants.

Changes

Cohort / File(s) Summary
FP8 Runtime Selectors
tilelang/utils/target.py, tilelang/utils/__init__.py
Added determine_fp8_type and determine_torch_fp8_type and re-exported them to determine platform-appropriate FP8 dtype strings and PyTorch dtypes at runtime.
Examples — GEMM FP8 (runtime dtypes)
examples/gemm_fp8/example_tilelang_gemm_amd.py, examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py, examples/gemm_fp8/example_tilelang_gemm_fp8.py, examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py, examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
Replaced hardcoded FP8 dtype constants with calls to determine_fp8_type / determine_torch_fp8_type; adjusted defaults (e.g., in_dtype=None) to pick dtypes at runtime; platform-aware benchmarking branches added (HIP vs non-HIP); intrinsic example switches emitter selection and tile sizing to emitter-driven values.
HIP FP8 Type System
src/tl_templates/hip/hip_fp8.h
Reworked FP8 typedefs into HIP-specific aliases and added explicit wrapper structs and aligned storage wrappers for E4/E5 FP8 representations with constructors/conversions for interoperability.
MFMA Intrinsic Generator
tilelang/intrinsics/mfma_macro_generator.py
Added recognition/abbreviations for new FP8 variant float8_e4m3fn (fp8) and BF8 paths; extended k-dim eligibility and MFMA suffix/prefix logic to handle additional FP8/BF8 variants.
HIP Codegen dtype map
src/target/codegen_hip.cc
Added MFMA dtype mappings for float8_e5m2x4 and float8_e5m2x8 to support new FP8 variants in codegen substitutions.
Tests updated for runtime dtypes
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
Replaced hardcoded FP8 dtype literals in test parametrizations with calls to determine_fp8_type(...) for platform-aware test dtype selection.

Sequence Diagram

sequenceDiagram
    participant App as Application / Example
    participant DetermineTorch as determine_torch_fp8_type()
    participant DetermineStr as determine_fp8_type()
    participant Platform as Platform Detector
    participant Torch as PyTorch

    App->>DetermineTorch: request torch.dtype for FP8
    DetermineTorch->>DetermineStr: request FP8 dtype string
    DetermineStr->>Platform: detect target (CUDA / ROCm / gfx)
    Platform-->>DetermineStr: platform info / gfx arch
    DetermineStr-->>DetermineTorch: return dtype string (e.g., "float8_e4m3fn" / "float8_e4m3fnuz")
    DetermineTorch->>Torch: map string -> torch.dtype
    Torch-->>DetermineTorch: return torch.dtype
    DetermineTorch-->>App: provide runtime torch.dtype for kernels/tests
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • Gongen-Ali
  • LeiWang1999

Poem

🐇 I nibble bytes and sniff the air,
Picking FP8 with gentle care.
CUDA, ROCm, whichever fits,
I wrap, I map, the kernel flits.
Hops of code — a performant pair.

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 5.71% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: fixing ROCm FP8 dtype selection and MFMA support on specific AMD GPUs (gfx942/gfx950).

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@tilelang/utils/__init__.py`:
- Line 3: Remove the unused "# noqa: F401" from the import line in
tilelang/utils/__init__.py so Ruff no longer flags the directive as unused;
locate the line that imports determine_target, select_fp8_e4m3_dtype, and
select_torch_fp8_e4m3_dtype and delete the trailing "# noqa: F401"
(alternatively, if you intended to silence F401, enable F401 in the Ruff config
instead).

In `@tilelang/utils/target.py`:
- Around line 73-79: The dtype-selection logic currently queries device 0 via
torch.cuda.get_device_properties(0); change it to use the active CUDA/HIP device
by calling torch.cuda.current_device() (or equivalent) and pass that index into
torch.cuda.get_device_properties so the gcn_arch check (gcnArchName) reflects
the currently selected GPU; update the block in tilelang/utils/target.py where
torch.version.hip, torch.cuda.is_available(), props =
torch.cuda.get_device_properties(0), and gcn_arch.startswith("gfx950") are used
to instead call torch.cuda.get_device_properties(current_device) (using
torch.cuda.current_device()) before inspecting gcnArchName.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)

67-79: Consider adding a float constructor for API symmetry.

fp8_e5_t lacks a constructor from float while fp8_e4_t provides one (lines 43-51). If this asymmetry is intentional for the current use cases, this is fine. Otherwise, consider adding it for API consistency:

♻️ Optional: Add float constructor to fp8_e5_t
 struct fp8_e5_t {
   unsigned char data;
   __device__ fp8_e5_t() {}
   __device__ fp8_e5_t(hip_fp8_e5_t val) {
     data = *reinterpret_cast<unsigned char *>(&val);
   }
+  __device__ fp8_e5_t(float val) {
+    data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
+  }
   __device__ operator hip_fp8_e5_t() const {
     return *reinterpret_cast<const hip_fp8_e5_t *>(&data);
   }
   __device__ operator float() const {
     return static_cast<float>(static_cast<hip_fp8_e5_t>(*this));
   }
 };

"""The profiler and convert to torch utils"""

from .target import determine_target # noqa: F401
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove the unused # noqa: F401 to satisfy Ruff.

Ruff flags the directive as unused on this line, which can fail linting. Either drop it or enable F401 in the config.

🧹 Proposed fix
-from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype # noqa: F401
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
🧰 Tools
🪛 Ruff (0.14.14)

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🤖 Prompt for AI Agents
In `@tilelang/utils/__init__.py` at line 3, Remove the unused "# noqa: F401" from
the import line in tilelang/utils/__init__.py so Ruff no longer flags the
directive as unused; locate the line that imports determine_target,
select_fp8_e4m3_dtype, and select_torch_fp8_e4m3_dtype and delete the trailing
"# noqa: F401" (alternatively, if you intended to silence F401, enable F401 in
the Ruff config instead).

Comment on lines 73 to 79
if torch.version.hip is None:
return "float8_e4m3fn"
if not torch.cuda.is_available():
return "float8_e4m3fnuz"
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
if gcn_arch.startswith("gfx950"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp && find . -type f -name "target.py" | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

find . -type f -name "target.py" -path "*/tilelang/utils/*"

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

find . -type f -name "*.py" | grep -i target | head -20

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

git ls-files | grep target.py

Repository: tile-ai/tilelang

Length of output: 84


🏁 Script executed:

cat -n tilelang/utils/target.py

Repository: tile-ai/tilelang

Length of output: 9539


🏁 Script executed:

rg "select_fp8_e4m3_dtype" --type py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 3304


🏁 Script executed:

rg "torch.cuda.get_device_properties" --type py -B 3 -A 3

Repository: tile-ai/tilelang

Length of output: 6868


🏁 Script executed:

web_search
PyTorch torch.cuda.current_device() torch.cuda.get_device_properties() multi-GPU best practices

Repository: tile-ai/tilelang

Length of output: 291


🌐 Web query:

PyTorch torch.cuda.current_device() best practices multi-GPU systems

💡 Result:

What torch.cuda.current_device() means (and why it can surprise you)

  • PyTorch tracks a per-process “currently selected” CUDA device; torch.cuda.current_device() returns that device index. By default it’s typically cuda:0 unless you change it. CUDA tensor allocation defaults to this device. [1][2]
  • If you restrict visibility with CUDA_VISIBLE_DEVICES, then cuda:0 is “GPU 0 among the visible set”, which may map to different physical GPUs in different processes—so seeing current_device()==0 can be totally correct in multi-process setups. [1][5]

Best practices on multi-GPU systems

1) Prefer explicit devices over relying on the “current device”

Use torch.device("cuda:N") (or a passed-in device) and move tensors/modules explicitly (.to(device)). This avoids accidental allocations on the wrong GPU when the current device changes. [1]

2) If you do rely on the current device, set it intentionally (per process)

  • Use torch.cuda.set_device(local_rank) early before creating CUDA tensors (or wrap code in with torch.cuda.device(local_rank):). [1][4]
  • This matters for distributed/NCCL: some collectives use torch.cuda.current_device() to decide which GPU to use, and PyTorch explicitly puts the responsibility on you to ensure each rank has the right GPU selected. [4]

3) With torchrun/DDP: map one process ↔ one GPU

  • torchrun provides LOCAL_RANK; use it to select the GPU and to configure DDP device_ids / output_device. [3]
  • Minimal pattern:
import os, torch
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = model.to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], output_device=local_rank
)

[3][4]

4) Use CUDA_VISIBLE_DEVICES to control which GPUs PyTorch can see

This is often the cleanest way to partition GPUs across jobs/process groups, and it changes the meaning of device indices inside the process (re-indexing from 0). [1][5]

5) When creating new tensors inside forward, don’t depend on the current device

If you need a new tensor “like an existing tensor”, use x.new_* / new_tensor so it inherits device/dtype, rather than torch.* factory calls that can follow the current-device context. [1]

Sources: PyTorch CUDA semantics / best practices [1], torch.cuda API docs [2], torchrun docs [3], torch.distributed docs (NCCL note about current_device) [4], PyTorch forum guidance on current_device, set_device, and CUDA_VISIBLE_DEVICES behavior [5].


Use the current device when querying GPU architecture for dtype selection.

In multi-GPU ROCm/HIP systems, device 0 may not be the active device. When selecting the FP8 dtype based on GPU architecture (gfx950 vs. other), the function must query the current device instead of hardcoding device 0, otherwise it will return the incorrect dtype.

🔧 Proposed fix
-    props = torch.cuda.get_device_properties(0)
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if torch.version.hip is None:
return "float8_e4m3fn"
if not torch.cuda.is_available():
return "float8_e4m3fnuz"
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
if gcn_arch.startswith("gfx950"):
if torch.version.hip is None:
return "float8_e4m3fn"
if not torch.cuda.is_available():
return "float8_e4m3fnuz"
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
gcn_arch = getattr(props, "gcnArchName", "")
if gcn_arch.startswith("gfx950"):
🤖 Prompt for AI Agents
In `@tilelang/utils/target.py` around lines 73 - 79, The dtype-selection logic
currently queries device 0 via torch.cuda.get_device_properties(0); change it to
use the active CUDA/HIP device by calling torch.cuda.current_device() (or
equivalent) and pass that index into torch.cuda.get_device_properties so the
gcn_arch check (gcnArchName) reflects the currently selected GPU; update the
block in tilelang/utils/target.py where torch.version.hip,
torch.cuda.is_available(), props = torch.cuda.get_device_properties(0), and
gcn_arch.startswith("gfx950") are used to instead call
torch.cuda.get_device_properties(current_device) (using
torch.cuda.current_device()) before inspecting gcnArchName.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`:
- Around line 237-240: The selected fp8 helpers (select_fp8_e4m3_dtype and
select_fp8_e5m2_dtype) return dtype names as strings but
tl_matmul/assert_tl_matmul_correctness expect TileLang dtype objects (e.g.,
T.float8_e4m3fn); update the calls so you convert the returned string to the
TileLang dtype object before passing to assert_tl_matmul_correctness (for
example by resolving the string via the TileLang type namespace, e.g., mapping
or using getattr(T, dtype_name) to get T.float8_e4m3fn), and use those resolved
dtype objects when invoking assert_tl_matmul_correctness and tl_matmul.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)

160-179: Add storage ctor/assignment for fp8_e5_4_t for API parity.

fp8_e4_4_t supports construction and assignment from its storage type (uint32_t); fp8_e5_4_t doesn't. If generated code attempts to use packed storage for E5, this API gap can cause compilation errors. Align fp8_e5_4_t with the fp8_e4_4_t interface by adding a storage constructor and assignment operator.

♻️ Suggested parity additions
 struct __align__(4) fp8_e5_4_t {
   union {
     fp8_e5_4_storage_t data;
     struct {
       fp8_e5_t x;
       fp8_e5_t y;
       fp8_e5_t z;
       fp8_e5_t w;
     };
   };
   __device__ fp8_e5_4_t() {}
+  __device__ fp8_e5_4_t(const fp8_e5_4_storage_t &val) : data(val) {}
   __device__ fp8_e5_4_t(const hip_fp8x4_e5_t &val) {
     data = *reinterpret_cast<const fp8_e5_4_storage_t *>(&val);
   }
   __device__ operator hip_fp8x4_e5_t() const {
     return *reinterpret_cast<const hip_fp8x4_e5_t *>(&data);
   }
+  __device__ fp8_e5_4_t &operator=(const fp8_e5_4_storage_t &val) {
+    data = val;
+    return *this;
+  }
 };

@hubertlu-tw hubertlu-tw changed the title [AMD] Fix gfx950 FP8 E4M3 selection in AMD FP8 examples [AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950 Jan 28, 2026
Copy link
Member

@LeiWang1999 LeiWang1999 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution, I left a simple comment that would be better to rename the select_fp8_type into determine_fp8_type.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)

308-309: Inconsistent FP8 dtype selection: still using hardcoded FNUZ types.

While test_gemm_ss_fp8_rocm, test_gemm_sr_fp8_rocm, and test_gemm_rr_fp8_rocm use determine_fp8_type(), this test still hardcodes T.float8_e5m2fnuz and T.float8_e4m3fnuz. This inconsistency could cause failures on gfx950 where non-FNUZ types are preferred.

Consider updating these parametrizations to use determine_fp8_type() for consistency:

♻️ Proposed fix
 `@pytest.mark.parametrize`(
     "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
     [
-        (128, 128, 128, True, True, T.float8_e5m2fnuz, T.float8_e5m2fnuz, T.float32, 128, 128, 32, 2, 128),
-        (128, 128, 128, True, True, T.float8_e4m3fnuz, T.float8_e4m3fnuz, T.float32, 128, 128, 32, 2, 128),
+        (128, 128, 128, True, True, determine_fp8_type("e5m2"), determine_fp8_type("e5m2"), T.float32, 128, 128, 32, 2, 128),
+        (128, 128, 128, True, True, determine_fp8_type(), determine_fp8_type(), T.float32, 128, 128, 32, 2, 128),
     ],
 )
🤖 Fix all issues with AI agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`:
- Around line 237-240: determine_fp8_type() returns a string but tl_matmul (and
assert_tl_matmul_correctness) expect a TileLang dtype object (e.g.,
T.float8_e4m3fn); update the calls that set e4m3_dtype and e5m2_dtype to convert
the returned string to the TileLang dtype object (use getattr on the T module
with the string), so the dtype passed into assert_tl_matmul_correctness and
ultimately tl_matmul matches the expected T.<dtype> objects.

Copy link
Member

@LeiWang1999 LeiWang1999 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your conrtibution!

@LeiWang1999 LeiWang1999 merged commit a55a823 into tile-ai:main Jan 29, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants