[Dtype] Improve host codegen handling for subtype #1517
LeiWang1999 merged 6 commits into tile-ai:main from
Conversation
- Updated CUDA vectorized cast functions to ensure proper handling of float16, float32, bfloat16, and float8 conversions, adding checks for bit sizes.
- Refactored dtype conversion logic in `cuda_fp4.h` to utilize `cudaRoundZero` for improved accuracy in floating-point conversions.
- Introduced a new method in `KernelParam` to convert TVM DataType to TileLang dtype.
- Adjusted argument binding logic in `arg_binder.cc` to allow for better subtype matching based on total bit counts.
- Enhanced dtype handling in `dtypes.py` to accommodate new float4_e2m1fn types and ensure compatibility with PyTorch.

This update aims to improve type safety and conversion accuracy across the codebase.
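The reviews below confirm that the new `KernelParam` accessor is a thin wrapper over the existing dtype field. As rough orientation, a minimal sketch of that shape of API could look like the following; the class body, type hints, and the `import tilelang.language as T` alias are assumptions for illustration, not the actual `tilelang/engine/param.py` code:

```python
# Sketch only: illustrates the accessor pattern described in the reviews below.
import tvm
import tilelang.language as T  # assumed import alias


class KernelParam:
    def __init__(self, dtype: tvm.DataType, shape):
        self.dtype = dtype  # TVM DataType of the kernel parameter
        self.shape = shape

    def torch_dtype(self):
        """Existing accessor: maps the TVM DataType to a torch.dtype."""
        ...

    def tilelang_dtype(self):
        # New accessor: convert the TVM DataType to the TileLang dtype wrapper,
        # mirroring torch_dtype() and centralizing the conversion in one place.
        return T.dtype(self.dtype)
```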
📝 Walkthrough

This PR extends CUDA vectorized casting support by adding FP4 and FP8 conversion pathways, expands public FP4 conversion APIs across multiple numeric types, refactors argument binding to handle sub-byte data types with runtime shape assertions, and updates the dtype system to support new floating-point formats with fallback handling.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/target/codegen_cuda.cc (1)
744-765: Fix ordering in `PrintVecElemStore` to avoid `ICHECK` failures for `float4_e2m1fn` with 8 lanes
`PrintVecElemLoad` handles `t.is_float4_e2m1fn()` before the generic `t.lanes() > 4 && t.lanes() <= 8` case, so FP4 vectors use the dedicated nested-struct accessor logic. In `PrintVecElemStore`, the order is reversed:

```cpp
} else if (t.lanes() > 4 && t.lanes() <= 8) {
  std::string type_name;
  if (t.bits() == 16) { ... } else if (t.bits() == 32) { ... }
  ICHECK(!type_name.empty());
  ...
} else if (t.is_float4_e2m1fn()) {
  // fp4_e2_64_t / 32_t / 16_t / 8_t / 4_t / 2_t path
  ...
}
```

For a `float4_e2m1fn` vector with `lanes == 8` and `bits == 4`, the `t.lanes() > 4 && t.lanes() <= 8` branch matches first, `type_name` remains empty (neither bits==16 nor 32), and the `ICHECK(!type_name.empty())` fires. The FP4-specific path is never reached.

Reorder the branches to mirror `PrintVecElemLoad`:

```diff
-  } else if (t.lanes() > 4 && t.lanes() <= 8) {
+  } else if (t.is_float4_e2m1fn()) {
+    stream << vec;
+    if (t.lanes() >= 64) stream << "." << access[i / 32];
+    if (t.lanes() >= 32) stream << "." << access[(i % 32) / 16];
+    if (t.lanes() >= 16) stream << "." << access[(i % 16) / 8];
+    if (t.lanes() >= 8) stream << "." << access[(i % 8) / 4];
+    stream << "." << access[i % 4] << " = " << value << ";\n";
+  } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     ...
```

That way, all FP4 vector widths (including 8-lane) take the intended nested-struct path, and the generic branch continues to serve 16- and 32-bit integer/float vectors only.
Also applies to: 826-862
src/transform/arg_binder.cc (1)
335-345: Restrict `data_is_subtype` to packed formats to preserve bool/int1/int4 compatibility

The new subtype logic correctly handles packed formats like FP4 by validating total bit count instead of per-dimension types:
- You break out of shape binding for `buffer->dtype.bits() < 8` (lines 335-345).
- Later, `data_is_subtype = buffer->dtype.bits() < 8` (line 527) triggers the total-bits assertion instead of detailed dtype checking.

However, this catches bool and 1/4-bit integer types as well. The issue:

- The bool handling block (lines ~525-545) builds a permissive `cond` accepting int8/uint8/kDLBool with bits 1 or 8.
- But when `data_is_subtype = true`, the `if (!data_is_subtype)` error path (line 558) is skipped.
- So `cond` is never consulted for error reporting; only the total-bits assertion runs.
- If bool is encoded as int8 (actual = 8 bits vs expected = 1 bit), the total-bits check fails even though the permissive bool rules would accept it.

This breaks existing interoperability: bool tensors previously accepted via int8 or kDLBool(bits=8) backing will now be rejected.
To preserve previous bool/int1/int4 behavior while keeping the FP4 optimization, exclude bool explicitly (and optionally restrict the subtype path to known packed formats):
```cpp
bool data_is_subtype = buffer->dtype.bits() < 8 && !buffer->dtype.is_bool();
// Optionally also restrict to known packed formats
```

or route bool back to the permissive block by keeping `!data_is_subtype` true for bool, so the existing `cond` logic applies.

Also applies to: 525-559
tilelang/jit/adapter/tvm_ffi.py (1)
144-158: Avoid mutating `native_shape[-1]` for sub-8-bit dtypes at init time; move scaling into the runtime shape construction instead

Two correctness issues with the new block:
- `native_shape` can be empty for scalar parameters (e.g., `KernelParam.from_var` with `dtype.bits < 8` such as bool), so `native_shape[-1]` will raise `IndexError` before the adapter is usable.
- When the last dimension is symbolic (`tir.Var`), multiplying/dividing it here converts it to a generic `PrimExpr`. The `func` method at line 226 only runs `dynamic_symbolic_map` resolution for `isinstance(s, tir.Var)`, so this dimension will never be resolved to a concrete Python integer. You'll end up passing a TVM expression into `torch.empty(*shape, ...)`, causing a runtime error.

The robust solution is:
- Keep `param_shapes` in logical units (remove lines 153–157).
- In `func`, after building `shape` for an output using `param_shapes[i]` + `dynamic_symbolic_map` (when `shape` contains Python ints), apply bit-ratio rescaling on `shape[-1]` for sub-8-bit dtypes with a guard `if shape:` (a worked numeric example follows this comment).

Example fix in `func`:

```diff
 for i in range(len(self.params)):
     if i in self.result_idx:
         dtype = param_dtypes[i]
         shape = []
         # ... existing shape resolution logic ...
         for s in param_shapes[i]:
             if isinstance(s, tir.Var):
                 # ... dynamic resolution ...
             else:
                 shape.append(s)
+
+        tl_dtype = self.params[i].dtype
+        if getattr(tl_dtype, "bits", None) is not None and tl_dtype.bits < 8 and shape:
+            storage_dtype: dtype = dtype(self.params[i].torch_dtype())
+            shape[-1] = (
+                shape[-1]
+                * tl_dtype.bits
+                * tl_dtype.lanes
+                // (storage_dtype.bits * storage_dtype.lanes)
+            )
+
         if len(shape) == 0:
             # ... error handling ...
         tensor = torch.empty(*shape, dtype=dtype, device=out_device)
```

Also fix the typo: `stroage_dtype` → `storage_dtype` on line 155.
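As a concrete illustration of the bit-ratio rescaling above, with hypothetical sizes and assuming `float4_e2m1fn` elements backed by `uint8` storage (two logical elements per byte):

```python
# Hypothetical numbers, just to show which way the rescaling goes.
logical_last_dim = 128               # float4_e2m1fn elements in the last dimension
logical_bits, logical_lanes = 4, 1
storage_bits, storage_lanes = 8, 1   # uint8 backing

storage_last_dim = (
    logical_last_dim * logical_bits * logical_lanes
    // (storage_bits * storage_lanes)
)
assert storage_last_dim == 64        # 128 fp4 values pack into 64 storage bytes
```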
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
2858-2931: Guard the `ramp_lanes = value_lanes / element_lanes` adjustment for sub-byte element types

The new ramp handling for sub-byte packed elements:
```cpp
int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
                     ? value_dtype.lanes() / element_dtype.lanes()
                     : value_dtype.lanes();
if (arith::ramp(base, 1, ramp_lanes).Match(index)) { ... }
```

and the mirror in `BufferStore` assume that `value_dtype.lanes()` is an integer multiple of `element_dtype.lanes()` whenever `element_dtype.bits() < 8` and `element_dtype.lanes() > 1`.

If a future lowering accidentally produces a combination where that divisibility doesn't hold, this division will silently truncate and the ramp pattern won't match the actual indexing semantics.
To make failures more obvious and easier to diagnose, consider adding a defensive check in the sub‑byte branch, e.g.:
```cpp
if (element_dtype.lanes() > 1 && element_dtype.bits() < 8) {
  ICHECK_EQ(value_dtype.lanes() % element_dtype.lanes(), 0)
      << "Unexpected lanes for packed sub-byte buffer load/store: value_dtype="
      << value_dtype << ", element_dtype=" << element_dtype;
  ramp_lanes = value_dtype.lanes() / element_dtype.lanes();
} else {
  ramp_lanes = value_dtype.lanes();
}
```
BufferStore. This doesn’t change behavior in the valid cases, but will catch misconfigurations early.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- src/target/codegen_cuda.cc
- src/tl_templates/cuda/cuda_fp4.h
- src/transform/arg_binder.cc
- testing/python/language/test_tilelang_language_vectorized_cast.py
- tilelang/engine/param.py
- tilelang/jit/adapter/tvm_ffi.py
- tilelang/language/v2/dtypes.py
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/engine/param.py (1)
- tilelang/language/v2/dtypes.py (1): `dtype` (14-15)

src/target/codegen_cuda.cc (1)
- src/target/utils.cc (2): `IsCudaVectorizableFP8` (137-140), `IsCudaVectorizableFP8` (137-137)

tilelang/jit/adapter/tvm_ffi.py (3)
- tilelang/language/v2/dtypes.py (1): `dtype` (14-15)
- tilelang/language/v2/annot.py (2): `dtype` (580-580), `dtype` (628-628)
- tilelang/engine/param.py (1): `torch_dtype` (127-141)

tilelang/language/v2/dtypes.py (1)
- tilelang/engine/param.py (1): `torch_dtype` (127-141)

src/tl_templates/cuda/cuda_fp4.h (2)
- src/tl_templates/cuda/cuda_fp8.h (1): `float2` (294-302)
- tilelang/language/v2/dtypes.py (1): `double` (241-241)

testing/python/language/test_tilelang_language_vectorized_cast.py (1)
- tilelang/language/v2/dtypes.py (5): `float4_e2m1fn` (390-390), `float16` (299-299), `float32` (300-300), `float64` (301-301), `bfloat16` (397-397)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
tilelang/engine/param.py (1)
143-150: Accessor `tilelang_dtype` looks consistent and non-invasive

This method cleanly mirrors `torch_dtype`, centralizing the `T.dtype(self.dtype)` conversion and avoiding duplication at call sites. No behavioral concerns from this change.

src/tl_templates/cuda/cuda_fp4.h (1)
157-273: FP4 conversion helpers are consistent with CUDA fp4/fp8 patterns

The new FP4 helper suite (`__tl_cvt_fp4*_...`) is structured coherently:
- Uses the `__nv_cvt_*` fp4 APIs from `<cuda_fp4.h>` with `cudaRoundZero`, then reinterprets raw types into `__half`/`half2`/`float2`/`double2`/`__nv_bfloat16` (`__nv_bfloat162`).
- Mirrors the existing FP8 helpers' style (storage-based `*_storage_t` operands, vectorized `*_x2` variants, and round-zero semantics) so codegen can treat FP4 in the same way as FP8.

I don't see correctness issues here; just make sure your CUDA minimum version and `__CUDA_ARCH__ >= 800` assumptions match the availability of these fp4 APIs in your supported toolchains.
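For readers less familiar with the e2m1 layout these helpers convert to and from, the following is a small pure-Python reference decoder. It is independent of the CUDA intrinsics and of the actual `cuda_fp4.h` helpers, and is only meant to show which values a 4-bit e2m1 code can represent:

```python
def decode_fp4_e2m1(code: int) -> float:
    """Decode a 4-bit e2m1 value: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
    sign = -1.0 if (code >> 3) & 0x1 else 1.0
    exp = (code >> 1) & 0x3
    man = code & 0x1
    if exp == 0:
        # Subnormal range: 0 or 0.5
        return sign * 0.5 * man
    # Normal range: exponent bias 1, implicit leading one, half-step mantissa
    return sign * (2.0 ** (exp - 1)) * (1.0 + 0.5 * man)


# All representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6
print(sorted({abs(decode_fp4_e2m1(c)) for c in range(16)}))
```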
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/jit/adapter/tvm_ffi.py (1)
144-158: Guard sub-byte shape rewrite against empty shapes
`native_shape[-1]` is unconditionally accessed when `tl_dtype.bits < 8`. For any sub-byte param with an empty shape (e.g. a scalar bool/int1 or a 0-D tensor), this will raise an IndexError during adapter construction.

You can make this robust by guarding on `native_shape` and tightening the intent to true tensor params. Proposed adjustment:
```diff
-            tl_dtype = param.dtype
-            if tl_dtype.bits < 8:
-                stroage_dtype: dtype = dtype(param.torch_dtype())
-                # last dim divide by bits to get the actual shape
-                native_shape[-1] = native_shape[-1] * tl_dtype.bits * tl_dtype.lanes // (stroage_dtype.bits * stroage_dtype.lanes)
+            tl_dtype = param.dtype
+            if tl_dtype.bits < 8 and native_shape:
+                storage_dtype: dtype = dtype(param.torch_dtype())
+                # reinterpret last dim based on logical vs storage bit width
+                native_shape[-1] = (
+                    native_shape[-1]
+                    * tl_dtype.bits * tl_dtype.lanes
+                    // (storage_dtype.bits * storage_dtype.lanes)
+                )
```

tilelang/language/v2/dtypes.py (1)
72-89: Fix `float4_e2m1fnx2` assertion to match actual PyTorch attribute name

The assertion on line 195 checks for `torch.float4_e2m1fnx2` (no underscore), but the actual PyTorch attribute is `torch.float4_e2m1fn_x2` (with underscore), as correctly used in the mapping at line 84. This causes the assertion to fail even when the dtype is available.

Update the assertion to check for the correct attribute name:
```diff
     elif dtype_str == "float4_e2m1fnx2":
-        assert hasattr(torch, "float4_e2m1fnx2"), (
-            "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
+        assert hasattr(torch, "float4_e2m1fn_x2"), (
+            "torch.float4_e2m1fn_x2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
         )
         return torch.float4_e2m1fn_x2
```

src/transform/arg_binder.cc (1)
336-343: Exclude bool from the packed-subtype path to preserve int8/uint8 compatibility

The new `data_is_subtype` check treats `buffer->dtype.bits() < 8` as a packed format (e.g., FP4/INT4), but bool (1 bit) is unintentionally included. This bypasses the explicit bool↔int8/uint8 compatibility logic (lines 506-523) and enforces a total-bits equality that fails when a bool buffer receives an int8 DLTensor with matching shapes, a case documented as supported.

Restrict the subtype path to exclude bool:
```diff
 for (size_t k = 0; k < buffer->shape.size(); ++k) {
-  if (buffer->dtype.bits() < 8) {
+  if (buffer->dtype.bits() < 8 && !buffer->dtype.is_bool()) {
     break;
   }
```

and

```diff
-  bool data_is_subtype = buffer->dtype.bits() < 8;
+  bool data_is_subtype = buffer->dtype.bits() < 8 && !buffer->dtype.is_bool();
```

This preserves the documented bool↔int8/uint8 compatibility while restricting total-bit assertions to true packed formats.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
119-135: Broader FP4 test matrix is good; consider restoring standard main entrypoint

The new FP4 cases in `test_vectorized_cast_fp4` (half/float/double/bfloat16) nicely exercise the new CUDA paths. Two caveats:

- `run_vectorized_cast` returns early whenever `src_dtype` or `dst_dtype` is `T.float4_e2m1fn`, so for all these new cases you only assert on the presence of the vectorized intrinsic string and never on numerical correctness. If/when Torch FP4 support stabilizes, it would be worth re-enabling the value checks for at least the FP4↔float32 path.
- In the `__main__` guard, `tilelang.testing.main()` is commented out in favor of a single direct `test_vectorized_cast_fp4(...)` call. That's handy for local debugging but surprising in-tree; running this file as a script will no longer execute the full parametrized test suite.

You may want to restore `tilelang.testing.main()` (and keep the direct call under a temporary or developer-only flag) so script runs match the pytest behavior; see the sketch after this comment.

Also applies to: 137-139
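A sketch of what restoring the standard entrypoint could look like (the direct single-case call is omitted here; where it should live is a project decision):

```python
import tilelang.testing

if __name__ == "__main__":
    # Run the full parametrized suite when the file is executed as a script;
    # keep any single-case direct call behind a developer-only guard instead.
    tilelang.testing.main()
```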
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- src/target/codegen_cuda.cc
- src/tl_templates/cuda/cuda_fp4.h
- src/transform/arg_binder.cc
- testing/python/language/test_tilelang_language_vectorized_cast.py
- tilelang/engine/param.py
- tilelang/jit/adapter/tvm_ffi.py
- tilelang/language/v2/dtypes.py
🧰 Additional context used
🧬 Code graph analysis (5)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
- tilelang/language/v2/dtypes.py (5): `float4_e2m1fn` (390-390), `float16` (299-299), `float32` (300-300), `float64` (301-301), `bfloat16` (397-397)

tilelang/language/v2/dtypes.py (1)
- tilelang/engine/param.py (1): `torch_dtype` (127-141)

tilelang/jit/adapter/tvm_ffi.py (2)
- tilelang/language/v2/dtypes.py (1): `dtype` (14-15)
- tilelang/engine/param.py (1): `torch_dtype` (127-141)

src/target/codegen_cuda.cc (1)
- src/target/utils.cc (2): `IsCudaVectorizableFP8` (137-140), `IsCudaVectorizableFP8` (137-137)

src/tl_templates/cuda/cuda_fp4.h (2)
- src/tl_templates/cuda/cuda_fp8.h (1): `float2` (294-302)
- tilelang/language/v2/dtypes.py (1): `double` (241-241)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (3)
tilelang/engine/param.py (1)
143-151: New `tilelang_dtype` accessor is consistent and useful

Exposing `KernelParam.tilelang_dtype()` as `T.dtype(self.dtype)` mirrors `torch_dtype()` and centralizes TVM→TileLang dtype conversion. This is a clean, low-risk addition.

src/target/codegen_cuda.cc (1)
975-1050: Vectorized cast and sub-byte load/store changes look consistent with FP4/FP8 design
- Tightening the f16/bf16/FP8 vectorized cast conditions to explicitly require 32-bit float sources/targets (`target_ty.bits() == 32` / `from_ty.bits() == 32`) avoids accidentally applying these intrinsics to doubles or other float widths.
- New FP4 paths (double↔float4_e2m1fn, bfloat16↔float4_e2m1fn) follow the established pattern:
  - Use 2-lane chunks via `PrintVectorizedCast`.
  - Reinterpret FP4 vectors through `uint8_t*` in line with storage-based FP4 APIs.
  - Delegate actual numeric work to the helpers defined in `cuda_fp4.h`.
- The adjusted `ramp_lanes` logic in `BufferLoad`/`BufferStore` ensures vector load/store recognition still works when the buffer element itself is a packed sub-byte vector type (e.g. FP4 packed into int8), rather than assuming element lanes == value lanes.

Overall the changes are coherent with the new FP4 storage model and vectorization strategy; the main follow-up is to ensure all used intrinsics exist with the expected signatures for the targeted CUDA/SM versions.
Also applies to: 1105-1150, 2858-2865, 2926-2933
src/tl_templates/cuda/cuda_fp4.h (1)
157-273: FP4 conversion helpers are symmetric and match the storage-oriented design

The new FP4 helpers cover all expected pairs:
- Half/half2, float/float2, double/double2, and bfloat16/bfloat162, all expressed in terms of `__nv_fp4_storage_t`/`__nv_fp4x2_storage_t`.
- "To FP4" paths consistently use the CUDA `__nv_cvt_*_to_fp4[x2]` intrinsics with `cudaRoundZero`.
- "From FP4" paths either:
  - Convert to half/half2 via `*_to_halfraw*` then reinterpret raw→typed, or
  - Build float/double(/bfloat16) results via the corresponding half/float/bfloat intrinsics.

This lines up with how `codegen_cuda.cc` calls these functions for vectorized casts and with the intended storage-dtype contract for sub-byte FP4 tensors.
as title. Thanks @Hamerlate