[Language] Enhance T.dtype.as_torch conversion for compatibility #1473
Conversation
- Added support for new float8 and float4 data types in the `__dtype_as_torch__` method.
- Implemented backend-specific handling for `float8_e4m3` based on HIP or CUDA.
- Included assertions to ensure compatibility with the required PyTorch versions for each dtype.
- Improved error handling for unsupported dtypes.
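The dispatch pattern described above can be sketched as follows. This is a simplified illustration, not the PR's actual code: torch dtype objects are stood in by strings so the sketch runs without PyTorch, and the `is_hip` parameter is a hypothetical stand-in for the `torch.version.hip` backend check.

```python
# Simplified sketch of the dtype-dispatch pattern described above (hypothetical,
# not the PR's actual code). Torch dtypes are stood in by strings so this runs
# without PyTorch installed.

_STR_TO_TORCH_DTYPE = {  # ordinary one-to-one mappings
    "float32": "torch.float32",
    "float16": "torch.float16",
    "bfloat16": "torch.bfloat16",
}

def dtype_as_torch(dtype_str: str, is_hip: bool = False) -> str:
    """Convert a TileLang dtype string to a (stand-in) torch dtype name."""
    if dtype_str == "float8_e4m3":
        # Backend-specific: HIP/ROCm uses the fnuz variant, CUDA the fn variant.
        return "torch.float8_e4m3fnuz" if is_hip else "torch.float8_e4m3fn"
    if dtype_str in _STR_TO_TORCH_DTYPE:
        return _STR_TO_TORCH_DTYPE[dtype_str]
    raise ValueError(
        f"Cannot convert dtype '{dtype_str}' to torch.dtype. "
        f"Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
```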
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/v2/dtypes.py (1)
193-196: Consider updating error message to include all supported dtypes.

The error message lists `_STR_TO_TORCH_DTYPE.keys()` as supported dtypes, but the specially-handled variants (`float8_e4m3`, `e4m3fnuz_float8`, `float8_e8m0fnu`, `float4_e2m1fnx2`) are not in that dictionary. Users receiving this error won't see the full list of convertible dtypes.

🔎 Proposed fix

```diff
+_SPECIAL_TORCH_DTYPES = {"float8_e4m3", "float8_e5m2", "e4m3fnuz_float8", "float8_e8m0fnu", "float4_e2m1fnx2"}
+
 def __dtype_as_torch__(self: dtype) -> torch.dtype:
     """Convert TileLang dtype to PyTorch dtype."""
     dtype_str = str(self)
     # ... existing if-elif chain ...
     elif dtype_str in _STR_TO_TORCH_DTYPE:
         return _STR_TO_TORCH_DTYPE[dtype_str]
-    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
+    all_supported = list(_STR_TO_TORCH_DTYPE.keys()) + list(_SPECIAL_TORCH_DTYPES)
+    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {all_supported}")
```
📜 Review details

📒 Files selected for processing (1)

- tilelang/language/v2/dtypes.py (1 hunks)
🔇 Additional comments (2)
tilelang/language/v2/dtypes.py (2)
161-174: LGTM! Backend-specific handling is well-implemented.

The HIP vs CUDA detection using `torch.version.hip` is the correct approach, and the assertions provide clear upgrade guidance for users with older PyTorch versions.
178-182: Verify the `e4m3fnuz_float8` dtype string origin.

The dtype string `e4m3fnuz_float8` uses a reversed naming convention compared to other float8 variants (e.g., `float8_e4m3fnuz`). This string doesn't appear in `_all_dtypes` or the TYPE_CHECKING class definitions. Is this an intentional alias for compatibility with an external system or a legacy naming convention? Consider adding a brief comment explaining its origin.
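If the reversed name is indeed an alias, one way to document and contain it is a small canonicalization table, sketched here as a hypothetical helper (not part of the PR) so the main dispatch only ever sees canonical spellings:

```python
# Hypothetical alias table mapping non-standard dtype spellings onto the
# canonical float8 names; a comment next to each entry records its origin.
_DTYPE_ALIASES = {
    "e4m3fnuz_float8": "float8_e4m3fnuz",  # reversed legacy spelling
}

def canonicalize_dtype(dtype_str: str) -> str:
    """Return the canonical spelling for a dtype string, if an alias exists."""
    return _DTYPE_ALIASES.get(dtype_str, dtype_str)
```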
LGTM, Merged :)
…ions

- Commented out the main execution call in the test script and replaced it with a direct call to the test function `test_divmod()`.
- Enhanced error messages in the dtype conversion assertions to improve clarity and readability, ensuring proper guidance for required PyTorch versions.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tilelang/language/v2/dtypes.py (1)
183-192: Verify PyTorch version requirements and dtype naming.

Per the previous review, the version requirement and dtype naming may need updating:

- OCP Micro-scaling Format support may require `torch >= 2.9.0` instead of `>= 2.8.0`.
- The PyTorch dtype name may be `float4_e2m1fn_x2` (with underscore before x2) rather than `float4_e2m1fnx2`.
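The version requirement the review questions can be checked explicitly rather than assumed. A sketch of such a gate, assuming (as the review suggests, but unverified) that the float4 dtype landed in PyTorch 2.9.0; the helper names are hypothetical:

```python
def parse_torch_version(version_str: str) -> tuple:
    """Parse (major, minor, patch) from a torch version string, e.g. '2.9.0+cu121'."""
    base = version_str.split("+")[0]  # drop the local build suffix
    return tuple(int(p) for p in base.split(".")[:3])

def supports_float4_e2m1(version_str: str) -> bool:
    # Assumes, per the review (unverified), that the dtype requires torch >= 2.9.0.
    return parse_torch_version(version_str) >= (2, 9, 0)
```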
🧹 Nitpick comments (2)
tilelang/language/v2/dtypes.py (2)
178-182: Clarify the non-standard dtype string format.

The string `"e4m3fnuz_float8"` uses a reversed naming pattern compared to all other dtype strings (e.g., `"float8_e4m3fnuz"`). Is this intentional to support an external format from TVM or another upstream source? If so, consider adding a brief comment explaining this alias for future maintainers.
196-196: Error message doesn't list all supported dtypes.

The extended dtypes handled specially above (e.g., `float8_e4m3`, `float8_e5m2`, `float8_e8m0fnu`, `float4_e2m1fnx2`) are not included in `_STR_TO_TORCH_DTYPE`, so they won't appear in the error message. This may confuse users when they pass a valid extended dtype but have an older PyTorch version.

Consider either:

- Distinguishing between "unsupported dtype" and "dtype requires newer PyTorch", or
- Including the extended dtype strings in the error message.

🔎 Proposed fix

```diff
-    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
+    extended_dtypes = ["float8_e4m3", "float8_e5m2", "e4m3fnuz_float8", "float8_e8m0fnu", "float4_e2m1fnx2"]
+    all_supported = list(_STR_TO_TORCH_DTYPE.keys()) + extended_dtypes
+    raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {all_supported}")
```
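The first suggestion, separating an unknown dtype from one that merely needs a newer PyTorch, could look like the following sketch. The helper is hypothetical, and the per-dtype minimum versions are illustrative placeholders, not verified against PyTorch release notes:

```python
# Hypothetical helper separating "unknown dtype" from "needs newer PyTorch".
# Minimum versions below are illustrative placeholders, not verified.
_EXTENDED_DTYPE_MIN_TORCH = {
    "float8_e4m3": (2, 1, 0),
    "float8_e5m2": (2, 1, 0),
    "float8_e8m0fnu": (2, 7, 0),
    "float4_e2m1fnx2": (2, 9, 0),
}

def conversion_error(dtype_str: str, torch_version: tuple, known: dict) -> str:
    """Return an error message naming the actual problem, or '' if convertible."""
    if dtype_str in known:
        return ""  # ordinary mapping handles it
    needed = _EXTENDED_DTYPE_MIN_TORCH.get(dtype_str)
    if needed is None:
        return f"unknown dtype '{dtype_str}'"
    if torch_version < needed:
        return (f"dtype '{dtype_str}' requires PyTorch >= "
                f"{'.'.join(map(str, needed))}")
    return ""  # extended dtype, and this PyTorch is new enough
```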
🔇 Additional comments (2)
tilelang/language/v2/dtypes.py (2)
161-174: LGTM!

The backend-aware handling for `float8_e4m3` is well-implemented. The HIP vs CUDA detection using `torch.version.hip` is the correct approach, and the version assertions provide helpful guidance for users.
175-177: LGTM!

The handling for `float8_e5m2` is correct and consistent with the pattern established above.