7 changes: 2 additions & 5 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
@@ -227,13 +227,10 @@ def run_regression_perf():
out_dtype, accum_dtype = "float32", "float32"
in_dtype = T.float8_e4m3fn
kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
print(kernel_e4m3.get_kernel_source())

⚠️ Potential issue | 🟡 Minor

Debug print statement in regression performance function.

The print(kernel_e4m3.get_kernel_source()) will output the full kernel source on every regression run, which can clutter CI logs. If this was added for debugging during development, consider removing it or making it conditional (e.g., via an environment variable or verbose flag).

Suggested fix
 def run_regression_perf():
     M, N, K = 4096, 4096, 4096
     out_dtype, accum_dtype = "float32", "float32"
     in_dtype = T.float8_e4m3fn
     kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    print(kernel_e4m3.get_kernel_source())
     profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
     latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
     return latency_e4m3
🤖 Prompt for AI Agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py` at line 230, remove the unconditional debug print of the kernel source that clutters CI logs: either delete the call to print(kernel_e4m3.get_kernel_source()) in the regression performance function, or gate it behind an explicit opt-in (a VERBOSE environment variable, a debug/verbose flag, or a logger call at DEBUG level) so the kernel source is only emitted when explicitly requested.
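
As an illustration of the conditional option mentioned above, here is a minimal sketch that keeps the dump but gates it behind an opt-in, written against this example module's existing imports; the TILELANG_DUMP_KERNEL_SOURCE environment variable name is hypothetical:

import os

def run_regression_perf():
    M, N, K = 4096, 4096, 4096
    out_dtype, accum_dtype = "float32", "float32"
    in_dtype = T.float8_e4m3fn
    kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    # Dump the generated kernel source only when explicitly requested,
    # so routine regression runs keep CI logs clean.
    if os.environ.get("TILELANG_DUMP_KERNEL_SOURCE"):  # hypothetical opt-in variable
        print(kernel_e4m3.get_kernel_source())
    profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
    latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
    return latency_e4m3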

profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
in_dtype = T.float8_e5m2
kernel_e5m2 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
return latency_e4m3


if __name__ == "__main__":
38 changes: 38 additions & 0 deletions src/config.h
@@ -0,0 +1,38 @@
/*!
* \file tl/config.h
* \brief TileLang configuration utilities.
*/

#ifndef TVM_TL_CONFIG_H_
#define TVM_TL_CONFIG_H_

#include <tvm/ir/transform.h>

namespace tvm {
namespace tl {
namespace tl_config {

/*!
* \brief Check if vectorize planner verbose output is enabled.
*/
inline bool VectorizePlannerVerboseEnabled() {
auto ctxt = transform::PassContext::Current();
return ctxt
->GetConfig("tl.enable_vectorize_planner_verbose", Optional<Bool>())
.value_or(Bool(false));
}

/*!
* \brief Check if 256-bit vectorization is disabled.
*/
inline bool Vectorize256Disabled() {
auto ctxt = transform::PassContext::Current();
return ctxt->GetConfig("tl.disable_vectorize_256", Optional<Bool>())
.value_or(Bool(false));
}

} // namespace tl_config
} // namespace tl
} // namespace tvm

#endif // TVM_TL_CONFIG_H_
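
Both helpers read the current TVM PassContext, so the options can be toggled from Python at compile time. Below is a minimal sketch modeled on the pass_configs usage in the new test further down; the TL_DISABLE_VECTORIZE_256 member of tilelang.PassConfigKey is an assumption, while the string keys noted in the comments are the ones this header queries:

import tilelang
import tilelang.language as T

@tilelang.jit(pass_configs={
    # Maps to "tl.enable_vectorize_planner_verbose", read by VectorizePlannerVerboseEnabled().
    tilelang.PassConfigKey.TL_ENABLE_VECTORIZE_PLANNER_VERBOSE: True,
    # Maps to "tl.disable_vectorize_256", read by Vectorize256Disabled(); key name assumed.
    tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True,
})
def planner_debug_kernel():
    A = T.empty((4,), dtype=T.float32)
    with T.Kernel(1, threads=128):
        for i in T.vectorized(4):
            A[i] = T.infinity(T.float32)
    return A

# Compiling with the verbose flag prints the planner's per-access and strategy
# decisions to stderr.
kernel = planner_debug_kernel.compile()
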
148 changes: 97 additions & 51 deletions src/transform/loop_vectorize.cc
@@ -23,6 +23,7 @@
*/

#include "loop_vectorize.h"
#include "../config.h"
#include "../op/builtin.h"
#include "../op/utils.h"
#include "../target/utils.h"
@@ -166,15 +167,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
: arith::IRMutatorWithAnalyzer(analyzer), layout_map_(layout_map) {}

int Plan(const For &node) {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_vectorize_256 =
ctxt->GetConfig(kDisableVectorize256, Optional<Bool>());
bool disable_vectorize_256 =
opt_disable_vectorize_256.value_or(Bool(false));

Optional<Bool> opt_verbose =
ctxt->GetConfig(kEnableVectorizePlannerVerbose, Optional<Bool>());
bool verbose = opt_verbose.value_or(Bool(false));
bool disable_vectorize_256 = tl_config::Vectorize256Disabled();
bool verbose = tl_config::VectorizePlannerVerboseEnabled();

if (TargetSupportVectorize256(Target::Current(false)) &&
!disable_vectorize_256 &&
@@ -216,9 +210,11 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
<< "\n";
}

// Separate buffers into local/fragment vs memory (global/shared)
// Separate buffers into local/fragment vs memory (global/shared) vs
// call/cast
int local_fragment_min = initial_vector_size_;
int memory_min = initial_vector_size_;
int call_node_min = initial_vector_size_;
bool has_global_or_shared_buffer = false;

auto is_local_or_fragment = [](const Buffer &buf) {
@@ -236,43 +232,45 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
<< " -> vector_size=" << info.vector_size
<< (info.is_store ? " [store]" : " [load]") << "\n";
} else {
std::cerr << " [cast/extern/rng] -> vector_size=" << info.vector_size
std::cerr << " [cast/call] -> vector_size=" << info.vector_size
<< "\n";
}
}
if (!buffer.defined()) {
// CastNode and CallNode do not have a buffer defined.
local_fragment_min =
arith::ZeroAwareGCD(local_fragment_min, info.vector_size);
memory_min = arith::ZeroAwareGCD(memory_min, info.vector_size);
call_node_min = arith::ZeroAwareGCD(call_node_min, info.vector_size);
} else if (is_local_or_fragment(buffer)) {
local_fragment_min =
arith::ZeroAwareGCD(local_fragment_min, info.vector_size);
local_fragment_buffers.push_back(info);
} else {
// global, shared, shared.dyn, or non-buffer constraints
// (cast/extern/rng)
// global, shared, shared.dyn
memory_min = arith::ZeroAwareGCD(memory_min, info.vector_size);
has_global_or_shared_buffer = true;
}
}

if (verbose) {
std::cerr << " Computed mins: local_fragment_min=" << local_fragment_min
<< ", memory_min=" << memory_min
<< ", call_node_min=" << call_node_min << "\n";
}

if (has_seq_stmt) {
// For body contains SeqStmt (multiple statements).
// Use conservative strategy: take GCD of all buffers including local.
// The special local buffer optimization only applies to simple single
// BufferStore cases where we can be confident about the access pattern.
vector_size_ = arith::ZeroAwareGCD(local_fragment_min, memory_min);
vector_size_ = arith::ZeroAwareGCD(
arith::ZeroAwareGCD(local_fragment_min, memory_min), call_node_min);
if (verbose) {
std::cerr << " [Strategy] Has SeqStmt, using conservative GCD of all: "
<< "local_fragment_min=" << local_fragment_min
<< ", memory_min=" << memory_min
std::cerr << " [Strategy] Has SeqStmt, using conservative GCD of all"
<< " -> vector_size=" << vector_size_ << "\n";
}
} else if (has_global_or_shared_buffer) {
// Has memory buffers and simple case (no SeqStmt):
// ignore local/fragment constraints
vector_size_ = memory_min;
vector_size_ = arith::ZeroAwareGCD(memory_min, call_node_min);
if (verbose) {
std::cerr << " [Strategy] Has memory buffers (simple case), using "
"memory_min="
@@ -306,12 +304,13 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
}
}
} else {
// Only local/fragment buffers: use their min
vector_size_ = local_fragment_min;
// Only local/fragment buffers: use GCD of local_fragment_min and
// call_node_min
vector_size_ = arith::ZeroAwareGCD(local_fragment_min, call_node_min);
if (verbose) {
std::cerr << " [Strategy] Only local/fragment buffers, using "
"local_fragment_min="
<< local_fragment_min << "\n";
"GCD(local_fragment_min, call_node_min)="
<< vector_size_ << "\n";
}
}

@@ -411,40 +410,67 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
// address_of has a buffer load value, so we should analyze the buffer load
// node to update vector_size_.
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
} else if (node->op.same_as(tir::builtin::bitwise_and()) ||
node->op.same_as(tir::builtin::bitwise_or()) ||
node->op.same_as(tir::builtin::bitwise_xor()) ||
node->op.same_as(tir::builtin::bitwise_not()) ||
node->op.same_as(tir::builtin::shift_left()) ||
node->op.same_as(tir::builtin::shift_right())) {
// Bitwise operations can be vectorized
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
// for other call nodes, we should not apply vectorization
buffer_vector_infos_.push_back({Buffer(), 1, false, {}});

// For other call nodes, use PostOrderVisit to check buffer accesses
// and determine whether they stay invariant within the given vector width
auto check_buffer_access_invariant = [&](int target_vec_size) -> bool {
if (!inner_for_)
return true;
bool all_invariant = true;
PostOrderVisit(ffi::GetRef<PrimExpr>(node), [&](const ObjectRef &obj) {
if (!all_invariant)
return;
if (auto *load = obj.as<BufferLoadNode>()) {
auto transformed_indices =
TransformIndices(load->indices, load->buffer);
Array<PrimExpr> strides = GetBufferStrides(load->buffer);
PrimExpr elem_offset = 0;
for (size_t i = 0; i < transformed_indices.size(); ++i) {
elem_offset += transformed_indices[i] * strides[i];
}
if (!IsExprInvariantInVectorBoundary(elem_offset,
inner_for_->loop_var,
target_vec_size, analyzer_)) {
all_invariant = false;
}
} else if (auto *store = obj.as<BufferStoreNode>()) {
auto transformed_indices =
TransformIndices(store->indices, store->buffer);
Array<PrimExpr> strides = GetBufferStrides(store->buffer);
PrimExpr elem_offset = 0;
for (size_t i = 0; i < transformed_indices.size(); ++i) {
elem_offset += transformed_indices[i] * strides[i];
}
if (!IsExprInvariantInVectorBoundary(elem_offset,
inner_for_->loop_var,
target_vec_size, analyzer_)) {
all_invariant = false;
}
}
});
return all_invariant;
};

// Find the largest vector size where all buffer accesses are invariant
int call_node_vector_size = loop_extent_vector_size_;
while (call_node_vector_size > 1) {
if (check_buffer_access_invariant(call_node_vector_size)) {
break;
}
call_node_vector_size /= 2;
}
buffer_vector_infos_.push_back(
{Buffer(), call_node_vector_size, false, {}});
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

void CheckConditionVectorized(const PrimExpr &cond) {
// TODO: perform some checks here
}

PrimExpr VisitExpr_(const CastNode *node) final {
int cast_vector_size = arith::ZeroAwareGCD(
vector_load_bits_max_ / node->dtype.bits(), initial_vector_size_);
// Record cast constraint (use empty buffer to indicate cast)
buffer_vector_infos_.push_back({Buffer(), cast_vector_size, false, {}});
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

int ComputeBufferVectorSize(const Array<PrimExpr> &indices,
const Buffer &buffer, bool is_store) {
if (!inner_for_)
return initial_vector_size_;

int buffer_vec_size = loop_extent_vector_size_;

// Transform indices using layout_map if present
Array<PrimExpr> TransformIndices(const Array<PrimExpr> &indices,
const Buffer &buffer) {
auto transformed_indices = indices;
if (layout_map_.defined() && layout_map_.count(buffer)) {
ICHECK(IsBufferContiguous(buffer, analyzer_))
Expand Down Expand Up @@ -476,6 +502,26 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
Array<PrimExpr>{new_indices.rbegin(), new_indices.rend()};
}
}
return transformed_indices;
}

PrimExpr VisitExpr_(const CastNode *node) final {
int cast_vector_size = arith::ZeroAwareGCD(
vector_load_bits_max_ / node->dtype.bits(), initial_vector_size_);
// Record cast constraint (use empty buffer to indicate cast)
buffer_vector_infos_.push_back({Buffer(), cast_vector_size, false, {}});
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

int ComputeBufferVectorSize(const Array<PrimExpr> &indices,
const Buffer &buffer, bool is_store) {
if (!inner_for_)
return initial_vector_size_;

int buffer_vec_size = loop_extent_vector_size_;

// Transform indices using layout_map if present
auto transformed_indices = TransformIndices(indices, buffer);

// 1. Compute raw element offset
Array<PrimExpr> strides = GetBufferStrides(buffer);
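
To summarize the selection logic in this pass, here is a small Python model of how the planner combines the per-access constraints into a final vector size. It is an illustration only: ZeroAwareGCD is assumed to behave like an ordinary GCD where a zero operand yields the other value, the real implementation works on TIR expressions rather than plain integers, and the extra dynamic-alignment handling in the memory-buffer branch is omitted.

from math import gcd

def zero_aware_gcd(a: int, b: int) -> int:
    # Mirrors arith::ZeroAwareGCD: a zero operand yields the other value.
    return b if a == 0 else (a if b == 0 else gcd(a, b))

def call_node_vector_size(loop_extent_vector_size: int, access_is_invariant) -> int:
    # Mirrors the new CallNode handling: halve the candidate width until every
    # buffer access offset inside the call is invariant within a vector boundary.
    size = loop_extent_vector_size
    while size > 1 and not access_is_invariant(size):
        size //= 2
    return size

def plan_vector_size(local_fragment_min: int, memory_min: int, call_node_min: int,
                     has_seq_stmt: bool, has_global_or_shared_buffer: bool) -> int:
    if has_seq_stmt:
        # Multiple statements in the loop body: conservative GCD of everything.
        return zero_aware_gcd(zero_aware_gcd(local_fragment_min, memory_min),
                              call_node_min)
    if has_global_or_shared_buffer:
        # Simple case with global/shared accesses: local/fragment constraints
        # are ignored, but call/cast constraints still apply.
        return zero_aware_gcd(memory_min, call_node_min)
    # Only local/fragment buffers plus call/cast constraints.
    return zero_aware_gcd(local_fragment_min, call_node_min)

# Example: global/shared accesses allow 4 lanes, a call expression allows 8,
# and a local accumulator allows 2; the memory path wins in the simple case.
assert plan_vector_size(2, 4, 8, has_seq_stmt=False,
                        has_global_or_shared_buffer=True) == 4
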
33 changes: 33 additions & 0 deletions testing/python/language/test_tilelang_language_vectorize.py
@@ -1,6 +1,8 @@
import torch
import tilelang.testing
import tilelang.language as T

from tilelang.intrinsics import make_mma_swizzle_layout
import pytest


@@ -165,5 +167,36 @@ def test_vectorize_broadcast_int8(vec_num):
vectorize_broadcast_int8.compile(vec_num=vec_num)


@tilelang.jit
def vectorize_test_call_infinity():
A = T.empty((4,), dtype=T.float32)
with T.Kernel(1, threads=128):
for i in T.vectorized(4):
A[i] = T.infinity(T.float32)
return A


def test_vectorize_call_infinity():
kernel = vectorize_test_call_infinity.compile()
assert "float4" in kernel.get_kernel_source()


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_VECTORIZE_PLANNER_VERBOSE: True})
def vectorize_test_call_bitwise_logical():
A = T.empty((128, 32), dtype=T.float32)
with T.Kernel(1, threads=128):
A_shared = T.alloc_shared((128, 32), dtype=T.float32)
T.annotate_layout({A_shared: make_mma_swizzle_layout(A_shared)})
for i, j in T.Parallel(128, 32):
A_shared[i, j] = A[i, j]
return A


def test_vectorize_call_bitwise_logical():
kernel = vectorize_test_call_bitwise_logical.compile()
print(kernel.get_kernel_source())
assert "float4" in kernel.get_kernel_source()


if __name__ == "__main__":
tilelang.testing.main()