106 changes: 96 additions & 10 deletions src/target/codegen_cuda.cc
@@ -919,6 +919,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__half22float2(*((half2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// half8 -> float8
PrintIndent();
stream << "((float2*)(&" << sret << "))[0] = "
<< "__half22float2(*(half2*)(&(" << src << ")));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[1] = "
<< "__half22float2(*((half2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[2] = "
<< "__half22float2(*((half2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[3] = "
<< "__half22float2(*((half2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
} else if (from_ty.is_float() && target_ty.is_float16()) {
// Use __float22half2_rn for vectorized conversion (float2 -> half2)
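For reference, a minimal standalone sketch of the CUDA that the new 8-lane branch above emits for a half8 -> float8 cast. The `half8`/`float8` aggregates and the function name are hypothetical stand-ins for the vector types the generated kernel actually defines:

```cuda
#include <cuda_fp16.h>

// Hypothetical 8-lane aggregates; the generated kernel defines its own.
struct half8 {
  half2 v[4];
};
struct float8 {
  float2 v[4];
};

__device__ float8 cast_half8_to_float8(half8 src) {
  float8 dst;
  // Four half2 -> float2 conversions, one per pair of lanes, mirroring the
  // four statements the codegen emits.
  ((float2 *)(&dst))[0] = __half22float2(*(half2 *)(&src));
  ((float2 *)(&dst))[1] = __half22float2(*((half2 *)(&src) + 1));
  ((float2 *)(&dst))[2] = __half22float2(*((half2 *)(&src) + 2));
  ((float2 *)(&dst))[3] = __half22float2(*((half2 *)(&src) + 3));
  return dst;
}
```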
@@ -939,6 +955,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> half8
PrintIndent();
stream << "((half2*)(&" << sret << "))[0] = "
<< "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[1] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[2] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[3] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
}

@@ -965,6 +997,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// bfloat162x4 -> float8
PrintIndent();
stream << "((float2*)(&" << sret << "))[0] = "
<< "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
<< src << ")));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[1] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+1));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[2] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+2));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[3] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+3));\n";
os << sret;
return;
}
} else if (from_ty.is_float() && target_ty.is_bfloat16()) {
// Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
@@ -985,6 +1037,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> bfloat162x4
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
<< "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
}

@@ -1019,6 +1087,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> fp8x8
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+2), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+3), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
}
}
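Similarly, a sketch of the float8 -> fp8x8 lowering built on `__nv_cvt_float2_to_fp8x2` from `cuda_fp8.h`, which packs a float2 into two fp8 bytes with saturation. The e4m3 interpretation is shown, and `float8`/`fp8x8` are again hypothetical stand-ins:

```cuda
#include <cuda_fp8.h>

// Hypothetical aggregates: eight floats in, eight packed fp8 bytes out.
struct float8 {
  float2 v[4];
};
struct fp8x8 {
  __nv_fp8x2_storage_t v[4];
};

__device__ fp8x8 cast_float8_to_fp8x8_e4m3(float8 src) {
  fp8x8 dst;
  // Each call converts two floats to two __NV_E4M3 bytes; __NV_SATFINITE
  // clamps out-of-range inputs to the largest finite fp8 value.
  for (int i = 0; i < 4; ++i) {
    dst.v[i] = __nv_cvt_float2_to_fp8x2(src.v[i], __NV_SATFINITE, __NV_E4M3);
  }
  return dst;
}
```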

@@ -2345,16 +2441,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = PrintExpr(call->args[0]);
this->PrintIndent();
stream << "device_assert(" << cond << ");\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = PrintExpr(call->args[0]);
std::string msg_expr = PrintExpr(call->args[1]);
this->PrintIndent();
stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
} else {
CodeGenC::VisitStmt_(op);
}
17 changes: 15 additions & 2 deletions src/transform/layout_inference.cc
@@ -597,7 +597,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
}
}
// Update the best plan if this one uses fewer registers
// (ties broken in favor of the smaller infer root)
if (reg_num < min_reg_num) {
if (reg_num < min_reg_num ||
(reg_num == min_reg_num &&
attempt_infer_root < min_reg_num_infer_root)) {
best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map;
@@ -787,7 +789,18 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
}
});

if (has_non_local && !has_reducer) {
// If a cast operation exists, vectorization may still be required
bool has_cast_operations = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Flag stores whose value is a Cast; reducer stores are filtered out
// by the !has_reducer check below.
if (store->value.as<CastNode>()) {
has_cast_operations = true;
}
}
});

if ((has_non_local || has_cast_operations) && !has_reducer) {
for_node = VectorizeLoop(for_node);
}
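Pulled out of the pass for clarity, the new scan is equivalent to the helper below (a sketch against TVM's public TIR headers; `HasCastStore` is an illustrative name, not part of the PR):

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

using tvm::ObjectRef;
using namespace tvm::tir;

// Returns true when the loop body stores the result of a Cast anywhere,
// i.e. a loop that still benefits from vectorization (and thus from the
// vectorized cast intrinsics above) even if every buffer is local.
bool HasCastStore(const Stmt &body) {
  bool has_cast = false;
  PostOrderVisit(body, [&](const ObjectRef &obj) {
    if (const auto *store = obj.as<BufferStoreNode>()) {
      if (store->value.as<CastNode>()) {
        has_cast = true;
      }
    }
  });
  return has_cast;
}
```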

32 changes: 29 additions & 3 deletions testing/python/language/test_tilelang_language_vectorized_cast.py
@@ -17,15 +17,36 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):

@T.prim_func
def main(
A: T.Tensor[(M), dtype_A], # noqa: F821
B: T.Tensor[(M), dtype_B], # noqa: F821
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
Comment on lines +20 to +21 (Contributor review):

⚠️ Potential issue | 🟡 Minor

Remove unused noqa F821 directives.

They're unnecessary with the (M,) annotations and trip Ruff (RUF100 "Unused noqa directive (unused: F821)", Ruff 0.14.2, lines 20-21).

-            A: T.Tensor[(M,), dtype_A],  # noqa: F821
-            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],
+            B: T.Tensor[(M,), dtype_B],

Apply the same change in the parallel kernel. Also applies to lines 35-36.
):
with T.Kernel(1, threads=128):
T.copy(A, B)

return main


@tilelang.jit
def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
assert M % 256 == 0

@T.prim_func
def main(
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
A_local = T.alloc_fragment((M,), dtype_A)
B_local = T.alloc_fragment((M,), dtype_B)

T.copy(A, A_local)
for i in T.Parallel(M):
B_local[i] = A_local[i]
T.copy(B_local, B)

return main


def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
"""Run the vectorized cast kernel and check the correctness.
Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,

M = 128 * lanes
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)

A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()

kernel(A, B)
kernel_parallel(A, C)

torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)

code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()

assert check_str in code, \
assert check_str in code and check_str in code_parallel, \
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"

