
Conversation

@LeiWang1999 (Member) commented Dec 19, 2025

  • Added support for new float8 and float4 data types in the `__dtype_as_torch__` method (a sketch follows after this list).
  • Implemented backend-specific handling for float8_e4m3 based on HIP or CUDA.
  • Included assertions to ensure compatibility with the required PyTorch versions for each dtype.
  • Improved error handling for unsupported dtypes.
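
A minimal sketch of the resulting dispatch, assuming the names visible in the review diff (`__dtype_as_torch__`, `_STR_TO_TORCH_DTYPE`); the real method takes a TileLang `dtype` instance rather than a plain string, and the merged branch set may differ:

```python
import torch

# Abridged stand-in for the module-level table; the real mapping covers many more dtypes.
_STR_TO_TORCH_DTYPE = {"float32": torch.float32, "float16": torch.float16}

def dtype_as_torch(dtype_str: str) -> torch.dtype:
    """Backend-aware dtype dispatch in the style this PR describes (sketch only)."""
    if dtype_str == "float8_e4m3":
        # ROCm builds of PyTorch set torch.version.hip; CUDA builds leave it None.
        name = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3fn"
        assert hasattr(torch, name), f"torch.{name} is unavailable; please upgrade PyTorch"
        return getattr(torch, name)
    if dtype_str == "float8_e5m2":
        assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is unavailable; please upgrade PyTorch"
        return torch.float8_e5m2
    if dtype_str in _STR_TO_TORCH_DTYPE:
        return _STR_TO_TORCH_DTYPE[dtype_str]
    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype")
```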

Summary by CodeRabbit

  • New Features
    • Added support for more extended float8/float4 numeric types (multiple variants) to broaden precision options.
    • Implemented backend-aware handling so AMD ROCm and CUDA platforms select appropriate runtime dtypes.
    • Added runtime checks with clear upgrade hints when a needed dtype is unavailable.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Dec 19, 2025

Walkthrough

Refactors `__dtype_as_torch__`, replacing a simple dict lookup with backend-aware conditional branches that map extended float8/float4 dtype strings to specific torch dtypes, select between HIP and CUDA, and assert dtype availability with versioned upgrade hints.

Changes

Cohort / File(s): Dtype conversion refactoring (tilelang/language/v2/dtypes.py)
Summary: Replace the string-membership check with a multi-branch dispatch in __dtype_as_torch__; add HIP vs CUDA selection for float8_e4m3 variants; add branches for float8_e5m2, e4m3fnuz_float8, float8_e8m0fnu, and float4_e2m1fnx2; add runtime assertions for torch-dtype existence; preserve the fallback to _STR_TO_TORCH_DTYPE for other dtypes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Review HIP vs CUDA branching and platform detection logic.
  • Verify correctness and clarity of runtime assertions and error/version hints.
  • Confirm coverage of additional float8/float4 branches and fallback mapping behavior.


Poem

🐰
I hopped through code with nimble paws,
Chose CUDA left and HIP because,
I matched each float8, one by one,
Ensured the torch types shone like sun,
A rabbit's tweak — concise and fun! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title '[Language] Enhance T.dtype.as_torch conversion for compatibility' directly reflects the main change: enhancing the dtype conversion method with support for new float8/float4 types and backend-specific handling.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.



@coderabbitai (bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/language/v2/dtypes.py (1)

193-196: Consider updating the error message to include all supported dtypes.

The error message lists _STR_TO_TORCH_DTYPE.keys() as supported dtypes, but the specially-handled variants (float8_e4m3, e4m3fnuz_float8, float8_e8m0fnu, float4_e2m1fnx2) are not in that dictionary. Users receiving this error won't see the full list of convertible dtypes.

🔎 Proposed fix
+_SPECIAL_TORCH_DTYPES = {"float8_e4m3", "float8_e5m2", "e4m3fnuz_float8", "float8_e8m0fnu", "float4_e2m1fnx2"}
+
 def __dtype_as_torch__(self: dtype) -> torch.dtype:
     """Convert TileLang dtype to PyTorch dtype."""
     dtype_str = str(self)
     # ... existing if-elif chain ...
     elif dtype_str in _STR_TO_TORCH_DTYPE:
         return _STR_TO_TORCH_DTYPE[dtype_str]

-    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
+    all_supported = list(_STR_TO_TORCH_DTYPE.keys()) + list(_SPECIAL_TORCH_DTYPES)
+    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {all_supported}")
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 95e3b5a and e428879.

📒 Files selected for processing (1)
  • tilelang/language/v2/dtypes.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (2)
tilelang/language/v2/dtypes.py (2)

161-174: LGTM! Backend-specific handling is well-implemented.

The HIP vs CUDA detection using torch.version.hip is the correct approach, and the assertions provide clear upgrade guidance for users with older PyTorch versions.
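
As a quick illustration of that probe (not the PR's code), `torch.version.hip` is a version string on ROCm wheels and `None` on CUDA wheels:

```python
import torch

# Illustrative probe of the current build: torch.version.hip doubles as the
# HIP/CUDA discriminator, and hasattr shows which e4m3 variants this build ships.
print("HIP build:", torch.version.hip is not None)
print("has float8_e4m3fn:", hasattr(torch, "float8_e4m3fn"))
print("has float8_e4m3fnuz:", hasattr(torch, "float8_e4m3fnuz"))
```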


178-182: Verify the e4m3fnuz_float8 dtype string origin.

The dtype string e4m3fnuz_float8 uses a reversed naming convention compared to other float8 variants (e.g., float8_e4m3fnuz). This string doesn't appear in _all_dtypes or the TYPE_CHECKING class definitions.

Is this an intentional alias for compatibility with an external system or a legacy naming convention? Consider adding a brief comment explaining its origin.

@LeiWang1999
Member Author

LGTM, Merged :)

…ions

- Commented out the main execution call in the test script and replaced it with a direct call to the test function `test_divmod()`.
- Enhanced error messages in the dtype conversion assertions to improve clarity and readability, ensuring proper guidance for required PyTorch versions (an illustrative example follows below).
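
For illustration, an assertion message in the clearer style the commit describes (hypothetical wording; the exact text in the repository may differ):

```python
import torch

# Hypothetical example of the clearer assertion style; float8_e8m0fnu is one of
# the dtypes the PR handles and exists only in sufficiently recent PyTorch builds.
assert hasattr(torch, "float8_e8m0fnu"), (
    "Cannot convert 'float8_e8m0fnu': this PyTorch build does not provide "
    "torch.float8_e8m0fnu. Please upgrade PyTorch to use this dtype."
)
```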

@coderabbitai (bot) left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tilelang/language/v2/dtypes.py (1)

183-192: Verify PyTorch version requirements and dtype naming.

Per the previous review, the version requirement and dtype naming may need updating:

  • OCP Micro-scaling Format support may require torch >= 2.9.0 instead of >= 2.8.0
  • The PyTorch dtype name may be float4_e2m1fn_x2 (with an underscore before x2) rather than float4_e2m1fnx2 (a defensive lookup sketch follows below)
Reference: PyTorch float4_e2m1fn dtype version support.
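
Until the attribute name is confirmed, a defensive lookup that tolerates both spellings could look like this (illustrative; both names are assumptions pending verification):

```python
import torch

# Try the underscore spelling first, then the older one; which (if either)
# attribute exists depends on the installed PyTorch version.
fp4 = getattr(torch, "float4_e2m1fn_x2", None) or getattr(torch, "float4_e2m1fnx2", None)
assert fp4 is not None, "no packed float4 e2m1 dtype in this PyTorch build; upgrade required"
```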
🧹 Nitpick comments (2)
tilelang/language/v2/dtypes.py (2)

178-182: Clarify the non-standard dtype string format.

The string "e4m3fnuz_float8" uses a reversed naming pattern compared to all other dtype strings (e.g., "float8_e4m3fnuz"). Is this intentional to support an external format from TVM or another upstream source?

If so, consider adding a brief comment explaining this alias for future maintainers.


196-196: Error message doesn't list all supported dtypes.

The extended dtypes handled specially above (e.g., float8_e4m3, float8_e5m2, float8_e8m0fnu, float4_e2m1fnx2) are not included in _STR_TO_TORCH_DTYPE, so they won't appear in the error message. This may confuse users when they pass a valid extended dtype but have an older PyTorch version.

Consider either:

  1. Distinguishing between "unsupported dtype" and "dtype requires newer PyTorch", or
  2. Including the extended dtype strings in the error message.
🔎 Proposed fix
-    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
+    extended_dtypes = ["float8_e4m3", "float8_e5m2", "e4m3fnuz_float8", "float8_e8m0fnu", "float4_e2m1fnx2"]
+    all_supported = list(_STR_TO_TORCH_DTYPE.keys()) + extended_dtypes
+    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {all_supported}")
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e428879 and e0c7fd3.

📒 Files selected for processing (1)
  • tilelang/language/v2/dtypes.py (1 hunks)
🔇 Additional comments (2)
tilelang/language/v2/dtypes.py (2)

161-174: LGTM!

The backend-aware handling for float8_e4m3 is well-implemented. The HIP vs CUDA detection using torch.version.hip is the correct approach, and the version assertions provide helpful guidance for users.


175-177: LGTM!

The handling for float8_e5m2 is correct and consistent with the pattern established above.

@LeiWang1999 merged commit 3516f1e into tile-ai:main on Dec 19, 2025
5 of 6 checks passed
