diff --git a/src/support/ffi_aliases.h b/src/support/ffi_aliases.h index cbc6fb027..7dbe0b395 100644 --- a/src/support/ffi_aliases.h +++ b/src/support/ffi_aliases.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include diff --git a/src/transform/legalize_negative_index.cc b/src/transform/legalize_negative_index.cc index b502a6fba..f0df555ef 100644 --- a/src/transform/legalize_negative_index.cc +++ b/src/transform/legalize_negative_index.cc @@ -1,6 +1,6 @@ /*! * \file legalize_negative_index.cc - * \brief Legalize negative indices in buffer load expressions. + * \brief Legalize negative indices in buffer load/store expressions. */ #include @@ -10,6 +10,7 @@ #include #include +#include #include #include "arith/ir_mutator_with_analyzer.h" @@ -23,47 +24,42 @@ using arith::IRVisitorWithAnalyzer; enum class IndexSignState { kNonNegative, kNegative, kUnknown }; +using BufferAccessVariant = + std::variant; +using LoadStore2StateMap = + std::unordered_map>; + class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { public: - explicit NegativeIndexAnalyzer( - std::unordered_map> - *result) + explicit NegativeIndexAnalyzer(LoadStore2StateMap *result) : result_(result) {} - void VisitExpr_(const BufferLoadNode *op) final { - auto load = tvm::ffi::GetRef(op); +private: + std::vector ProcessIdx(const ffi::Array &indices, + ffi::String buffer_name) { std::vector states; - states.reserve(op->indices.size()); - bool needs_record = false; + states.reserve(indices.size()); - for (size_t i = 0; i < op->indices.size(); ++i) { - PrimExpr simplified = analyzer_.Simplify(op->indices[i]); + for (size_t i = 0; i < indices.size(); ++i) { + PrimExpr simplified = analyzer_.Simplify(indices[i]); + IndexSignState state = IndexSignState::kUnknown; // Handle scalar indices with the standard analyzer if (simplified.dtype().lanes() == 1) { - if (analyzer_.CanProve(simplified >= 0)) { - states.push_back(IndexSignState::kNonNegative); - continue; - } - if (analyzer_.CanProve(simplified < 0)) { - states.push_back(IndexSignState::kNegative); - needs_record = true; - continue; - } - states.push_back(IndexSignState::kUnknown); - needs_record = true; - DLOG(WARNING) - << "LegalizeNegativeIndex: cannot prove non-negative index " - << simplified << " for buffer " << load->buffer->name << " (axis " - << i << ")."; - continue; + if (analyzer_.CanProve(simplified >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(simplified < 0)) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; } - // Vector indices: try to reason about non-negativity/negativity // Common patterns are Ramp(base, stride, lanes) and Broadcast(value, // lanes). - IndexSignState vec_state = IndexSignState::kUnknown; - if (const auto *ramp = simplified.as()) { + else if (const auto *ramp = simplified.as()) { // Compute a safe lower/upper bound for the vector lanes // lower_bound = base_min + min(0, stride_min) * (lanes - 1) // upper_bound = base_max + max(0, stride_max) * (lanes - 1) @@ -85,118 +81,129 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { if (s_max > 0) upper += s_max * (lanes - 1); - if (lower >= 0) { - vec_state = IndexSignState::kNonNegative; - } else if (upper < 0) { - vec_state = IndexSignState::kNegative; - } else { - vec_state = IndexSignState::kUnknown; - } - } else if (const auto *bc = simplified.as()) { - auto v = analyzer_.Simplify(bc->value); - if (analyzer_.CanProve(v >= 0)) { - vec_state = IndexSignState::kNonNegative; - } else if (analyzer_.CanProve(v < 0)) { - vec_state = IndexSignState::kNegative; - } else { + if (lower >= 0) + state = IndexSignState::kNonNegative; + else if (upper < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; + } else if (const auto *broadcast = simplified.as()) { + auto v = analyzer_.Simplify(broadcast->value); + if (analyzer_.CanProve(v >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(v < 0)) + state = IndexSignState::kNegative; + else { // Try const bound if proof unavailable auto vb = analyzer_.const_int_bound(v); - if (vb->min_value >= 0) { - vec_state = IndexSignState::kNonNegative; - } else if (vb->max_value < 0) { - vec_state = IndexSignState::kNegative; - } else { - vec_state = IndexSignState::kUnknown; - } + if (vb->min_value >= 0) + state = IndexSignState::kNonNegative; + else if (vb->max_value < 0) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; } } + states.push_back(state); + } - if (vec_state == IndexSignState::kNonNegative) { - states.push_back(IndexSignState::kNonNegative); - continue; - } - if (vec_state == IndexSignState::kNegative) { - states.push_back(IndexSignState::kNegative); - needs_record = true; - continue; - } + return std::move(states); + } - states.push_back(IndexSignState::kUnknown); - needs_record = true; - DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " - << simplified << " for buffer " << load->buffer->name - << " (axis " << i << ")."; - } + bool NeedRecord(const std::vector &states) { + return std::any_of(states.begin(), states.end(), + [](const IndexSignState &state) { + return state == IndexSignState::kUnknown || + state == IndexSignState::kNegative; + }); + } + + void VisitExpr_(const BufferLoadNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); - if (needs_record) { + if (NeedRecord(states)) (*result_)[op] = std::move(states); - } IRVisitorWithAnalyzer::VisitExpr_(op); } + void VisitStmt_(const BufferStoreNode *op) final { + std::vector states = + ProcessIdx(op->indices, op->buffer->name); + + if (NeedRecord(states)) + (*result_)[op] = std::move(states); + + IRVisitorWithAnalyzer::VisitStmt_(op); + } + private: - std::unordered_map> - *result_; + LoadStore2StateMap *result_; }; class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { public: - static PrimFunc - Apply(PrimFunc func, - const std::unordered_map> &states) { + static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) { arith::Analyzer analyzer; NegativeIndexRewriter rewriter(&analyzer, states); - if (!func->body.defined()) { - return func; - } PrimFuncNode *func_node = func.CopyOnWrite(); func_node->body = rewriter.VisitStmt(func_node->body); return func; } private: - NegativeIndexRewriter( - arith::Analyzer *analyzer, - const std::unordered_map> &states) + NegativeIndexRewriter(arith::Analyzer *analyzer, + const LoadStore2StateMap &states) : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} + ffi::Array UpdateIdx(const ffi::Array &indices, + const ffi::Array &buffer_shape, + const std::vector &state_vec) { + ICHECK_EQ(state_vec.size(), indices.size()) + << "State vector size mismatch for buffer load/store indices (" + << indices << ")"; + ffi::Array new_indices = indices; + for (size_t i = 0; i < indices.size(); ++i) { + if (state_vec[i] != IndexSignState::kNegative) + continue; + new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i])); + } + return new_indices; + } + PrimExpr VisitExpr_(const BufferLoadNode *op) final { BufferLoad load = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); auto it = states_.find(op); - if (it == states_.end()) { + if (it == states_.end()) return load; - } - auto indices = load->indices; - bool changed = false; - - const auto &state_vector = it->second; - ICHECK_EQ(state_vector.size(), indices.size()) - << "State vector size mismatch for buffer load " << load->buffer->name; + auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second); + return BufferLoad(load->buffer, indices, load->predicate); + } - for (size_t i = 0; i < indices.size(); ++i) { - if (state_vector[i] != IndexSignState::kNegative) { - continue; - } - PrimExpr extent = load->buffer->shape[i]; - indices.Set(i, analyzer_->Simplify(extent + indices[i])); - changed = true; - } + Stmt VisitStmt_(const BufferStoreNode *op) final { + BufferStore store = + Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); - if (!changed) { - return load; - } + auto it = states_.find(op); + if (it == states_.end()) + return store; - return BufferLoad(load->buffer, indices); + auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second); + return BufferStore(store->buffer, store->value, indices, store->predicate); } - const std::unordered_map> - &states_; +private: + const LoadStore2StateMap &states_; }; PrimFunc LegalizeNegativeIndex(PrimFunc func) { @@ -204,8 +211,7 @@ PrimFunc LegalizeNegativeIndex(PrimFunc func) { return func; } - std::unordered_map> - states; + LoadStore2StateMap states; NegativeIndexAnalyzer analyzer(&states); analyzer(func->body); if (states.empty()) { diff --git a/testing/python/transform/test_tilelang_transform_legalize_negative_index.py b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py new file mode 100644 index 000000000..c5dd065aa --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py @@ -0,0 +1,342 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, expected): + """Helper function to verify structural equality after transformations""" + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.LegalizeNegativeIndex()(mod) + expected = tvm.IRModule.from_expr(expected.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], expected["main"], True) + + +def test_buffer_load_negative_index_legalized(): + """ + Test that negative indices are legalized by adding buffer extent. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + value = A[-1] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + value = A[1023] # A[-1] becomes A[1023] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices - only negative ones are legalized. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + value = A[-1, 10] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + value = A[1023, 10] # A[-1, 10] becomes A[1023, 10] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), "float32")): + value = A[-1, -2, -3] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), "float32")): + value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253 + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_negative_index_in_expression(): + """ + Test negative index as part of a larger expression. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + B = T.alloc_buffer((1024,), "float32") + for i in T.serial(1, 1024): + value = A[-i] + B[-i] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + B = T.alloc_buffer((1024,), "float32") + for i in T.serial(1, 1024): + value = A[1024 - i] + B[1024 - i] = value + + _check(before, after) + + +def test_buffer_load_non_negative_index_unchanged(): + """ + Test that non-negative indices remain unchanged. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + value = A[0] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # No changes expected for non-negative indices + value = A[0] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_unknown_sign_index_warning(): + """ + Test that indices with unknown sign trigger warnings but are processed. + This test mainly checks that the pass doesn't crash on unknown signs. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + i = T.Var("i", "int32") + value = A[i] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + i = T.Var("i", "int32") + # Unknown sign indices should remain unchanged + value = A[i] + B = T.alloc_buffer((1,), "float32") + B[0] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Broadcast(-1, 4) + value = A[vec] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Broadcast(-1, 4) # noqa: F841 + value = A[T.Broadcast(1023, 4)] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_vector_index_negative_ramp(): + """ + Test negative indices in vectorized operations (ramp case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + value = A[vec] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + value = A[T.Ramp(1020, 1, 4)] + B = T.alloc_buffer((4,), "float32") + B[T.Ramp(0, 1, 4)] = value + + _check(before, after) + + +def test_buffer_load_nested_buffer_loads(): + """ + Test legalization with nested buffer load expressions. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + inner_val = A[-1, 10] + outer_val = A[inner_val.astype("int32"), -2] + B = T.alloc_buffer((1,), "float32") + B[0] = outer_val + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + inner_val = A[1023, 10] + outer_val = A[inner_val.astype("int32"), 510] + B = T.alloc_buffer((1,), "float32") + B[0] = outer_val + + _check(before, after) + + +def test_buffer_store_negative_index(): + """ + Test negative indices in buffer store operations are legalized. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + A[-1] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + A[1023] = 42.0 + + _check(before, after) + + +def test_buffer_store_mixed_negative_positive_indices(): + """ + Test mixed negative and positive indices in buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512), "float32")): + A[-1, 10] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512), "float32")): + A[1023, 10] = 42.0 + + _check(before, after) + + +def test_buffer_store_multiple_negative_indices(): + """ + Test multiple negative indices in different dimensions for buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024, 512, 256), "float32")): + A[-1, -2, -3] = 42.0 + + @T.prim_func + def after(A: T.Tensor((1024, 512, 256), "float32")): + A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253 + + _check(before, after) + + +def test_buffer_store_negative_index_in_expression(): + """ + Test negative index as part of a larger expression in buffer store. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + for i in T.serial(1, 1024): + A[-i] = i * 2.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + for i in T.serial(1, 1024): + A[1024 - i] = i * 2.0 + + _check(before, after) + + +def test_buffer_store_vector_index_negative_broadcast(): + """ + Test negative indices in vectorized store operations (broadcast case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Broadcast(-1, 4) + values = T.Broadcast(42.0, 4) + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Broadcast(-1, 4) # noqa: F841 + values = T.Broadcast(42.0, 4) + A[T.Broadcast(1023, 4)] = values + + _check(before, after) + + +def test_buffer_store_vector_index_negative_ramp(): + """ + Test negative indices in vectorized store operations (ramp case). + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32")): + vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] + values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0] + A[vec] = values + + @T.prim_func + def after(A: T.Tensor((1024,), "float32")): + # vec is unused and can be delimed by Simplify. + vec = T.Ramp(-4, 1, 4) # noqa: F841 + values = T.Ramp(0.0, 1.0, 4) + A[T.Ramp(1020, 1, 4)] = values + + _check(before, after) + + +def test_buffer_store_nested_in_condition(): + """ + Test negative index buffer store within conditional statements. + """ + + @T.prim_func + def before(A: T.Tensor((1024,), "float32"), flag: T.int32): + if flag > 0: + A[-1] = 42.0 + else: + A[-2] = 24.0 + + @T.prim_func + def after(A: T.Tensor((1024,), "float32"), flag: T.int32): + if flag > 0: + A[1023] = 42.0 + else: + A[1022] = 24.0 + + _check(before, after) + + +if __name__ == "__main__": + tilelang.testing.main()