
[Dtype] Improve host codegen handling for subtype #1517

Merged
LeiWang1999 merged 6 commits into tile-ai:main from LeiWang1999:fp4_1223 on Dec 24, 2025

Conversation

@LeiWang1999 (Member) commented Dec 24, 2025

As the title says. Thanks @Hamerlate

Summary by CodeRabbit

  • New Features
    • Added FP4 data type conversions across multiple formats (Half, Float, Double, BFloat16)
    • Extended vectorized cast operations for FP4 and FP8 data types
    • Improved sub-byte data type handling with enhanced shape and stride management
    • Better precision type compatibility with refined bit-width operations


LeiWang1999 and others added 2 commits December 23, 2025 17:29
- Updated CUDA vectorized cast functions to ensure proper handling of float16, float32, bfloat16, and float8 conversions, adding checks for bit sizes.
- Refactored dtype conversion logic in `cuda_fp4.h` to utilize `cudaRoundZero` for improved accuracy in floating-point conversions.
- Introduced a new method in `KernelParam` to convert TVM DataType to TileLang dtype.
- Adjusted argument binding logic in `arg_binder.cc` to allow for better subtype matching based on total bit counts.
- Enhanced dtype handling in `dtypes.py` to accommodate new float4_e2m1fn types and ensure compatibility with PyTorch.

This update aims to improve type safety and conversion accuracy across the codebase.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Dec 24, 2025

Warning

Rate limit exceeded

@LeiWang1999 has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 12 minutes and 33 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

📥 Commits

Reviewing files that changed from the base of the PR and between 49bc41d and d25be32.

📒 Files selected for processing (2)
  • examples/flash_decoding/example_gqa_decode_varlen_logits.py
  • testing/python/language/test_tilelang_language_vectorized_cast.py
📝 Walkthrough

This PR extends CUDA vectorized casting support by adding FP4 and FP8 conversion pathways, expands public FP4 conversion APIs across multiple numeric types, refactors argument binding to handle sub-byte data types with runtime shape assertions, and updates the dtype system to support new floating-point formats with fallback handling.

Changes

Cohort / File(s): Summary

  • CUDA Vectorized Cast Codegen (src/target/codegen_cuda.cc)
    Adds conditional guards on 32-bit source/target for existing casts; introduces new vectorized paths for double2↔fp4x2, bfloat162↔fp4x2, float2↔fp8x2, and cross-type bfloat16↔float4 conversions with type suffix handling; retains the elementwise fallback.
  • FP4 Conversion APIs (src/tl_templates/cuda/cuda_fp4.h)
    Introduces 16 new public conversion methods using storage-typed interfaces for FP4↔half, FP4↔float, FP4↔double, and FP4↔bfloat16 (single and x2 variants); replaces older implementations with storage-based wrappers. (An E2M1 reference decoder is sketched after this table.)
  • Runtime Type Binding & Argument Resolution (src/transform/arg_binder.cc)
    Replaces exact dtype matches with a generalized subtype check (dtype.bits() < 8); adds a data_is_subtype flag and runtime total-bit consistency assertions using shape buffers; extends symbolic shape resolution with cascaded if-then-else guards and NULL-pointer safety.
  • Sub-byte Data Type Shape Adaptation (tilelang/jit/adapter/tvm_ffi.py)
    Adds shape reinterpretation for sub-byte dtypes: rescales the final dimension by the storage dtype's width to recompute the native shape of packed types.
  • Kernel Parameter Type Conversion (tilelang/engine/param.py)
    Adds a new public method tilelang_dtype() that converts a TVM DataType to a TileLang dtype via T.dtype().
  • Dtype System Compatibility (tilelang/language/v2/dtypes.py)
    Adds a guarded mapping for extended Torch dtypes (float4_e2m1fn_x2); falls back to torch.int8 at runtime when the float4 variants are unavailable.
  • Vectorized Cast Test Coverage (testing/python/language/test_tilelang_language_vectorized_cast.py)
    Expands parameterized tests for FP4↔half, FP4↔float, FP4↔double, and FP4↔bfloat16; modifies the test-runner invocation to call the FP4-to-float test directly.
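
For orientation, float4_e2m1fn packs one sign bit, two exponent bits, and one mantissa bit into a nibble. The sketch below is a minimal pure-Python reference decoder (the function name is illustrative and not part of the PR; it assumes the conventional sign/exponent/mantissa bit order), showing the value grid the new conversion paths map to and from:

def decode_fp4_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
    sign = -1.0 if (nibble >> 3) & 0x1 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        mag = 0.5 * man                              # subnormal: 0.m * 2^0
    else:
        mag = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)   # normal: 1.m * 2^(exp - 1)
    return sign * mag

# The eight non-negative codes decode to 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
assert [decode_fp4_e2m1(i) for i in range(8)] == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]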

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Hops through circuits, FP4 in sight,
Storage types aligned just right,
Shape assertions caught at runtime's call,
Vectorized paths for one and all!

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: Passed (check skipped; CodeRabbit's high-level summary is enabled).
  • Title Check: Passed. The title "[Dtype] Improve host codegen handling for subtype" accurately summarizes the main focus of the pull request, which centers on improving how host code generation handles dtype subtypes.
  • Docstring Coverage: Passed. Docstring coverage is 83.33%, above the required threshold of 80.00%.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/target/codegen_cuda.cc (1)

744-765: Fix ordering in PrintVecElemStore to avoid ICHECK failures for float4_e2m1fn with 8 lanes

PrintVecElemLoad handles t.is_float4_e2m1fn() before the generic t.lanes() > 4 && t.lanes() <= 8 case, so FP4 vectors use the dedicated nested-struct accessor logic.

In PrintVecElemStore, the order is reversed:

} else if (t.lanes() > 4 && t.lanes() <= 8) {
  std::string type_name;
  if (t.bits() == 16) { ... }
  else if (t.bits() == 32) { ... }
  ICHECK(!type_name.empty());
  ...
} else if (t.is_float4_e2m1fn()) {
  // fp4_e2_64_t / 32_t / 16_t / 8_t / 4_t / 2_t path
  ...
}

For a float4_e2m1fn vector with lanes == 8 and bits == 4, the t.lanes() > 4 && t.lanes() <= 8 branch matches first, type_name remains empty (neither bits==16 nor 32), and the ICHECK(!type_name.empty()) fires. The FP4-specific path is never reached.

Reorder the branches to mirror PrintVecElemLoad:

-  } else if (t.lanes() > 4 && t.lanes() <= 8) {
+  } else if (t.is_float4_e2m1fn()) {
+    stream << vec;
+    if (t.lanes() >= 64) stream << "." << access[i / 32];
+    if (t.lanes() >= 32) stream << "." << access[(i % 32) / 16];
+    if (t.lanes() >= 16) stream << "." << access[(i % 16) / 8];
+    if (t.lanes() >= 8)  stream << "." << access[(i % 8) / 4];
+    stream << "." << access[i % 4] << " = " << value << ";\n";
+  } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     ...

That way, all FP4 vector widths (including 8‑lane) take the intended nested-struct path, and the generic branch continues to serve 16- and 32-bit integer/float vectors only.

Also applies to: 826-862

src/transform/arg_binder.cc (1)

335-345: Restrict data_is_subtype to packed formats to preserve bool/int1/int4 compatibility

The new subtype logic correctly handles packed formats like FP4 by validating total bit count instead of per-dimension types:

  • Shape binding breaks out of the loop when buffer->dtype.bits() < 8 (lines 335-345)
  • Later, data_is_subtype = buffer->dtype.bits() < 8 (line 527) triggers the total-bits assertion instead of the detailed dtype checking

However, this catches bool and 1/4-bit integer types as well. The issue:

  • The bool handling block (lines ~525-545) builds a permissive cond accepting int8/uint8/kDLBool with bits 1 or 8
  • But when data_is_subtype = true, the if (!data_is_subtype) error path (line 558) is skipped
  • So cond is never consulted for error reporting—only the total bits assertion runs
  • If bool is encoded as int8 (actual=8 bits vs expected=1 bit), the total bits check fails even though the permissive bool rules would accept it

This breaks existing interoperability: bool tensors previously accepted via int8 or kDLBool(bits=8) backing will now be rejected.

To preserve previous bool/int1/int4 behavior while keeping the FP4 optimization, exclude bool and known packed formats explicitly:

bool data_is_subtype =
    buffer->dtype.bits() < 8 &&
    !buffer->dtype.is_bool();
    // Optionally also restrict to known packed formats

or route bool back to the permissive block by keeping !data_is_subtype true for bool, so the existing cond logic applies.

Also applies to: 525-559
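
As a standalone illustration of the total-bit equivalence this subtype path relies on (plain Python; the helper is hypothetical, not code from the PR): a logical (M, N) float4_e2m1fn buffer and a packed uint8 backing tensor of shape (M, N // 2) carry the same number of payload bits, which is what the runtime assertion checks instead of comparing per-dimension dtypes.

from math import prod

def total_bits(shape, bits_per_elem):
    """Total payload size in bits for a dense tensor of the given shape."""
    return prod(shape) * bits_per_elem

M, N = 128, 256
logical = total_bits((M, N), 4)        # (M, N) fp4 elements, 4 bits each
packed = total_bits((M, N // 2), 8)    # (M, N // 2) uint8 storage, 8 bits each
assert logical == packed               # the consistency the binder asserts for subtypes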

tilelang/jit/adapter/tvm_ffi.py (1)

144-158: Avoid mutating native_shape[-1] for sub‑8‑bit dtypes at init time; move scaling into the runtime shape construction instead

Two correctness issues with the new block:

  • native_shape can be empty for scalar parameters (e.g., KernelParam.from_var with dtype.bits < 8 such as bool), so native_shape[-1] will raise IndexError before the adapter is usable.
  • When the last dimension is symbolic (tir.Var), multiplying/dividing it here converts it to a generic PrimExpr. The func method at line 226 only runs dynamic_symbolic_map resolution for isinstance(s, tir.Var), so this dimension will never be resolved to a concrete Python integer. You'll end up passing a TVM expression into torch.empty(*shape, ...), causing a runtime error.

The robust solution is:

  • Keep param_shapes in logical units (remove lines 153–157).
  • In func, after building shape for an output using param_shapes[i] + dynamic_symbolic_map (when shape contains Python ints), apply bit-ratio rescaling on shape[-1] for sub-8-bit dtypes with a guard if shape:.

Example fix in func:

                 for i in range(len(self.params)):
                     if i in self.result_idx:
                         dtype = param_dtypes[i]
                         shape = []
                         # ... existing shape resolution logic ...
                         for s in param_shapes[i]:
                             if isinstance(s, tir.Var):
                                 # ... dynamic resolution ...
                             else:
                                 shape.append(s)
+
+                        tl_dtype = self.params[i].dtype
+                        if getattr(tl_dtype, "bits", None) is not None and tl_dtype.bits < 8 and shape:
+                            storage_dtype: dtype = dtype(self.params[i].torch_dtype())
+                            shape[-1] = (
+                                shape[-1]
+                                * tl_dtype.bits
+                                * tl_dtype.lanes
+                                // (storage_dtype.bits * storage_dtype.lanes)
+                            )
+
                         if len(shape) == 0:
                             # ... error handling ...
                         tensor = torch.empty(*shape, dtype=dtype, device=out_device)

Also fix the typo on line 155: stroage_dtype → storage_dtype.
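
Pulled out of the adapter, the last-dimension rescaling that both the original code and the suggested fix apply reduces to the arithmetic below (standalone Python sketch; the plain integer parameters stand in for the dtype objects' bits/lanes fields):

def storage_last_dim(last_dim, logical_bits, logical_lanes, storage_bits, storage_lanes):
    """Recompute the trailing dimension when a sub-byte logical dtype is viewed
    through its wider storage dtype (e.g. fp4 elements packed into uint8)."""
    return last_dim * logical_bits * logical_lanes // (storage_bits * storage_lanes)

# 256 fp4 elements (4-bit, 1 lane) occupy 128 uint8 storage elements (8-bit, 1 lane).
assert storage_last_dim(256, 4, 1, 8, 1) == 128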

🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)

2858-2931: Guard the ramp_lanes = value_lanes / element_lanes adjustment for sub‑byte element types

The new ramp handling for sub‑byte packed elements:

int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
                     ? value_dtype.lanes() / element_dtype.lanes()
                     : value_dtype.lanes();
if (arith::ramp(base, 1, ramp_lanes).Match(index)) {
  ...
}

and the mirror in BufferStore assume that value_dtype.lanes() is an integer multiple of element_dtype.lanes() whenever element_dtype.bits() < 8 and element_dtype.lanes() > 1.

If a future lowering accidentally produces a combination where that divisibility doesn’t hold, this division will silently truncate and the ramp pattern won’t match the actual indexing semantics.

To make failures more obvious and easier to diagnose, consider adding a defensive check in the sub‑byte branch, e.g.:

if (element_dtype.lanes() > 1 && element_dtype.bits() < 8) {
  ICHECK_EQ(value_dtype.lanes() % element_dtype.lanes(), 0)
      << "Unexpected lanes for packed sub-byte buffer load/store: value_dtype="
      << value_dtype << ", element_dtype=" << element_dtype;
  ramp_lanes = value_dtype.lanes() / element_dtype.lanes();
} else {
  ramp_lanes = value_dtype.lanes();
}

Same pattern can be applied in BufferStore. This doesn’t change behavior in the valid cases, but will catch misconfigurations early.
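
A small Python mirror of the lanes arithmetic under discussion, including the divisibility guard suggested above (illustrative only; the C++ in codegen_cuda.cc is authoritative):

def ramp_lanes(value_lanes: int, elem_lanes: int, elem_bits: int) -> int:
    """Packed sub-byte vector elements cover elem_lanes logical lanes each,
    so the ramp spans value_lanes / elem_lanes steps; otherwise it is unchanged."""
    if elem_lanes > 1 and elem_bits < 8:
        assert value_lanes % elem_lanes == 0, "unexpected lanes for packed sub-byte access"
        return value_lanes // elem_lanes
    return value_lanes

assert ramp_lanes(8, 2, 4) == 4    # eight fp4 lanes stored as fp4x2 elements
assert ramp_lanes(8, 1, 16) == 8   # ordinary 16-bit elements: unchanged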

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11f122e and 6a778a5.

📒 Files selected for processing (7)
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • src/transform/arg_binder.cc
  • testing/python/language/test_tilelang_language_vectorized_cast.py
  • tilelang/engine/param.py
  • tilelang/jit/adapter/tvm_ffi.py
  • tilelang/language/v2/dtypes.py
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/engine/param.py (1)
tilelang/language/v2/dtypes.py (1)
  • dtype (14-15)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
tilelang/jit/adapter/tvm_ffi.py (3)
tilelang/language/v2/dtypes.py (1)
  • dtype (14-15)
tilelang/language/v2/annot.py (2)
  • dtype (580-580)
  • dtype (628-628)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
tilelang/language/v2/dtypes.py (1)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
src/tl_templates/cuda/cuda_fp4.h (2)
src/tl_templates/cuda/cuda_fp8.h (1)
  • float2 (294-302)
tilelang/language/v2/dtypes.py (1)
  • double (241-241)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
tilelang/language/v2/dtypes.py (5)
  • float4_e2m1fn (390-390)
  • float16 (299-299)
  • float32 (300-300)
  • float64 (301-301)
  • bfloat16 (397-397)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
tilelang/engine/param.py (1)

143-150: Accessor tilelang_dtype looks consistent and non-invasive

This method cleanly mirrors torch_dtype, centralizing the T.dtype(self.dtype) conversion and avoiding duplication at call sites. No behavioral concerns from this change.

src/tl_templates/cuda/cuda_fp4.h (1)

157-273: FP4 conversion helpers are consistent with CUDA fp4/fp8 patterns

The new FP4 helper suite (__tl_cvt_fp4*_...) is structured coherently:

  • Uses the __nv_cvt_* fp4 APIs from <cuda_fp4.h> with cudaRoundZero, then reinterprets raw types into __half / half2 / float2 / double2 / __nv_bfloat16(__nv_bfloat162).
  • Mirrors the existing FP8 helpers’ style (storage-based *_storage_t operands, vectorized *_x2 variants, and round-zero semantics) so codegen can treat FP4 in the same way as FP8.

I don’t see correctness issues here; just make sure your CUDA minimum version and __CUDA_ARCH__ >= 800 assumptions match the availability of these fp4 APIs in your supported toolchains.
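
To make the round-zero semantics concrete, here is a plain-Python model of quantizing a float onto the E2M1 grid with round-toward-zero (an illustrative reference, not a reimplementation of the CUDA intrinsics; saturating values beyond ±6 to the largest finite magnitude is an assumption):

FP4_E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_fp4_e2m1_rz(x: float) -> float:
    """Round-toward-zero quantization onto the E2M1 value grid."""
    mag = min(abs(x), 6.0)  # assumed saturation at the largest finite E2M1 value
    q = max(v for v in FP4_E2M1_MAGNITUDES if v <= mag)  # largest representable <= |x|
    return -q if x < 0 else q

assert quantize_fp4_e2m1_rz(2.9) == 2.0    # truncates toward zero rather than rounding to 3.0
assert quantize_fp4_e2m1_rz(-1.7) == -1.5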

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/jit/adapter/tvm_ffi.py (1)

144-158: Guard sub-byte shape rewrite against empty shapes

native_shape[-1] is unconditionally accessed when tl_dtype.bits < 8. For any sub-byte param with an empty shape (e.g. a scalar bool/int1 or a 0‑D tensor), this will raise an IndexError during adapter construction.

You can make this robust by guarding on native_shape and tightening the intent to true tensor params:

Proposed adjustment
-            tl_dtype = param.dtype
-            if tl_dtype.bits < 8:
-                stroage_dtype: dtype = dtype(param.torch_dtype())
-                # last dim divide by bits to get the actual shape
-                native_shape[-1] = native_shape[-1] * tl_dtype.bits * tl_dtype.lanes // (stroage_dtype.bits * stroage_dtype.lanes)
+            tl_dtype = param.dtype
+            if tl_dtype.bits < 8 and native_shape:
+                storage_dtype: dtype = dtype(param.torch_dtype())
+                # reinterpret last dim based on logical vs storage bit width
+                native_shape[-1] = (
+                    native_shape[-1]
+                    * tl_dtype.bits * tl_dtype.lanes
+                    // (storage_dtype.bits * storage_dtype.lanes)
+                )
tilelang/language/v2/dtypes.py (1)

72-89: Fix float4_e2m1fnx2 assertion to match actual PyTorch attribute name

The assertion on line 195 checks for torch.float4_e2m1fnx2 (no underscore), but the actual PyTorch attribute is torch.float4_e2m1fn_x2 (with underscore), as correctly used in the mapping at line 84. This causes the assertion to fail even when the dtype is available.

Update the assertion to check for the correct attribute name:

 elif dtype_str == "float4_e2m1fnx2":
-    assert hasattr(torch, "float4_e2m1fnx2"), (
-        "torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
+    assert hasattr(torch, "float4_e2m1fn_x2"), (
+        "torch.float4_e2m1fn_x2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
     )
     return torch.float4_e2m1fn_x2
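
Separately from the attribute-name fix above, the fallback behaviour described in the walkthrough (use torch.float4_e2m1fn_x2 when the installed torch provides it, otherwise fall back to torch.int8 storage) can be sketched as follows; this is an illustration of the pattern, not the file's actual code:

import torch

# Packed FP4 torch dtype if available (torch >= 2.8.0 per the assertion message above),
# otherwise fall back to int8 as the raw storage dtype.
_FLOAT4_X2_TORCH_DTYPE = getattr(torch, "float4_e2m1fn_x2", torch.int8)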
src/transform/arg_binder.cc (1)

336-343: Exclude bool from the packed-subtype path to preserve int8/uint8 compatibility

The new data_is_subtype check treats buffer->dtype.bits() < 8 as a packed format (e.g., FP4/INT4), but bool (1 bit) is unintentionally included. This bypasses the explicit bool↔int8/uint8 compatibility logic (lines 506-523) and enforces a total-bits equality that fails when a bool buffer receives an int8 DLTensor with matching shapes—a case documented as supported.

Restrict the subtype path to exclude bool:

    for (size_t k = 0; k < buffer->shape.size(); ++k) {
-      if (buffer->dtype.bits() < 8) {
+      if (buffer->dtype.bits() < 8 && !buffer->dtype.is_bool()) {
         break;
       }

and

-    bool data_is_subtype = buffer->dtype.bits() < 8;
+    bool data_is_subtype = buffer->dtype.bits() < 8 && !buffer->dtype.is_bool();

This preserves the documented bool↔int8/uint8 compatibility while restricting total-bit assertions to true packed formats.

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)

119-135: Broader FP4 test matrix is good; consider restoring standard main entrypoint

The new FP4 cases in test_vectorized_cast_fp4 (half/float/double/bfloat16) nicely exercise the new CUDA paths. Two caveats:

  • run_vectorized_cast returns early whenever src_dtype or dst_dtype is T.float4_e2m1fn, so for all these new cases you only assert on the presence of the vectorized intrinsic string and never on numerical correctness. If/when Torch FP4 support stabilizes, it would be worth re‑enabling the value checks for at least the FP4↔float32 path.
  • In the __main__ guard, tilelang.testing.main() is commented out in favor of a single direct test_vectorized_cast_fp4(...) call. That’s handy for local debugging but surprising in-tree; running this file as a script will no longer execute the full parameterized test suite.

You may want to restore tilelang.testing.main() (and keep the direct call under a temporary or developer-only flag) so script runs match the pytest behavior.

Also applies to: 137-139

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11f122e and 49bc41d.

📒 Files selected for processing (7)
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • src/transform/arg_binder.cc
  • testing/python/language/test_tilelang_language_vectorized_cast.py
  • tilelang/engine/param.py
  • tilelang/jit/adapter/tvm_ffi.py
  • tilelang/language/v2/dtypes.py
🧰 Additional context used
🧬 Code graph analysis (5)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
tilelang/language/v2/dtypes.py (5)
  • float4_e2m1fn (390-390)
  • float16 (299-299)
  • float32 (300-300)
  • float64 (301-301)
  • bfloat16 (397-397)
tilelang/language/v2/dtypes.py (1)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
tilelang/jit/adapter/tvm_ffi.py (2)
tilelang/language/v2/dtypes.py (1)
  • dtype (14-15)
tilelang/engine/param.py (1)
  • torch_dtype (127-141)
src/target/codegen_cuda.cc (1)
src/target/utils.cc (2)
  • IsCudaVectorizableFP8 (137-140)
  • IsCudaVectorizableFP8 (137-137)
src/tl_templates/cuda/cuda_fp4.h (2)
src/tl_templates/cuda/cuda_fp8.h (1)
  • float2 (294-302)
tilelang/language/v2/dtypes.py (1)
  • double (241-241)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (3)
tilelang/engine/param.py (1)

143-151: New tilelang_dtype accessor is consistent and useful

Exposing KernelParam.tilelang_dtype() as T.dtype(self.dtype) mirrors torch_dtype() and centralizes TVM→TileLang dtype conversion. This is a clean, low-risk addition.

src/target/codegen_cuda.cc (1)

975-1050: Vectorized cast and sub-byte load/store changes look consistent with FP4/FP8 design

  • Tightening the f16/bf16/FP8 vectorized cast conditions to explicitly require 32‑bit float sources/targets (target_ty.bits() == 32 / from_ty.bits() == 32) avoids accidentally applying these intrinsics to doubles or other float widths.
  • New FP4 paths (double↔float4_e2m1fn, bfloat16↔float4_e2m1fn) follow the established pattern:
    • Use 2‑lane chunks via PrintVectorizedCast.
    • Reinterpret FP4 vectors through uint8_t* in line with storage-based FP4 APIs.
    • Delegate actual numeric work to the helpers defined in cuda_fp4.h.
  • The adjusted ramp_lanes logic in BufferLoad/BufferStore ensures vector load/store recognition still works when the buffer element itself is a packed sub-byte vector type (e.g. FP4 packed into int8), rather than assuming element lanes==value lanes.

Overall the changes are coherent with the new FP4 storage model and vectorization strategy; the main follow-up is to ensure all used intrinsics exist with the expected signatures for the targeted CUDA/SM versions.

Also applies to: 1105-1150, 2858-2865, 2926-2933
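
As background for the uint8_t* reinterpretation of FP4 vectors mentioned above, two E2M1 nibbles share one storage byte. A plain-Python sketch of the packing (the low-nibble-first lane order is an assumption and has not been checked against the actual CUDA storage layout):

def pack_fp4x2(lane0: int, lane1: int) -> int:
    """Pack two 4-bit codes into one byte (lane 0 assumed in the low nibble)."""
    return (lane0 & 0xF) | ((lane1 & 0xF) << 4)

def unpack_fp4x2(byte: int) -> tuple[int, int]:
    """Inverse of pack_fp4x2."""
    return byte & 0xF, (byte >> 4) & 0xF

assert unpack_fp4x2(pack_fp4x2(0x3, 0x7)) == (0x3, 0x7)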

src/tl_templates/cuda/cuda_fp4.h (1)

157-273: FP4 conversion helpers are symmetric and match the storage‑oriented design

The new FP4 helpers cover all expected pairs:

  • Half/half2, float/float2, double/double2, and bfloat16/bfloat162, all expressed in terms of __nv_fp4_storage_t / __nv_fp4x2_storage_t.
  • “To FP4” paths consistently use the CUDA __nv_cvt_*_to_fp4[x2] intrinsics with cudaRoundZero.
  • “From FP4” paths either:
    • Convert to half/half2 via *_to_halfraw* then reinterpret raw→typed, or
    • Build float/double(/bfloat16) results via the corresponding half/float/bfloat intrinsics.

This lines up with how codegen_cuda.cc calls these functions for vectorized casts and with the intended storage-dtype contract for sub-byte FP4 tensors.

LeiWang1999 merged commit bea40bd into tile-ai:main on Dec 24, 2025 (5 of 6 checks passed).