diff --git a/src/transform/simplify.cc b/src/transform/simplify.cc index 6f6fabe75..10f8b3a35 100644 --- a/src/transform/simplify.cc +++ b/src/transform/simplify.cc @@ -32,6 +32,7 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> { bool propagate_knowns_to_simplify_expressions{}; bool convert_boolean_to_and_of_ors{}; bool apply_constraints_to_boolean_branches{}; + bool enable_simplify_let_inline{true}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -61,7 +62,11 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> { "If true, simplify each branch of AND/OR under a constraints " "provided by the other " "branch", - refl::DefaultValue(false)); + refl::DefaultValue(false)) + .def_ro("enable_simplify_let_inline", + &SimplifyConfigNode::enable_simplify_let_inline, + "If true, inline let statements when possible", + refl::DefaultValue(true)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig", SimplifyConfigNode, BaseAttrsNode); @@ -323,6 +328,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } bool CanInlineLetStmt(const LetStmtNode *op) { + if (!config_->enable_simplify_let_inline) + return false; if (is_const_number(op->value)) return true; if (op->value.as<VarNode>()) diff --git a/testing/python/transform/test_tilelang_transform_simplify.py b/testing/python/transform/test_tilelang_transform_simplify.py index 3b7376820..1bf08d40c 100644 --- a/testing/python/transform/test_tilelang_transform_simplify.py +++ b/testing/python/transform/test_tilelang_transform_simplify.py @@ -1,90 +1,605 @@ +# ruff: noqa from tilelang import tvm as tvm import tilelang as tl import tilelang.language as T import tilelang.testing +from tilelang.transform import PassConfigKey +from tvm import te -def modify( - with_B: bool = False, - with_bias: bool = False, -): - @T.prim_func - def main( - A: T.Tensor((64, 64)), - B: T.Tensor((64, 64)), - C: T.Tensor((64, 64)), - D: T.Tensor((64, 64)), - bias: T.Tensor((64, 64)), - ): - if with_B: - 
if with_bias: - T.gemm(A, bias, D) - T.gemm(A, B, D) - else: - with T.block(): - A_shared = T.alloc_shared((64, 64), dtype=T.float32) - C_shared = T.alloc_shared((64, 64), dtype=T.float32) - D_shared = T.alloc_shared((64, 64), dtype=T.float32) - T.copy(A, A_shared) - T.copy(C, C_shared) - T.gemm(A_shared, C_shared, D_shared) - T.copy(D_shared, D) - - return main - - -def test_modify(with_B=False, with_bias=False): - tester = modify(with_B=with_B, with_bias=with_bias) - mod = tvm.IRModule({tester.attrs["global_symbol"]: tester}) - mod2 = tl.transform.Simplify()(mod) - assert mod != mod2 - - -def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): - @T.prim_func - def main( - a: T.handle, - b: T.handle, - c: T.handle, - ): - A = T.match_buffer(a, (M, K), dtype=dtype) - B = T.match_buffer(b, (K, N), dtype=dtype) - C = T.match_buffer(c, (M, N), dtype=accum_dtype) - - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def test_matmul(): - func = matmul(1024, 1024, 1024, 128, 128, 32) - mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - mod = tl.transform.Simplify()(mod) - kernel = tl.compile(mod["main"], out_idx=[2]) - - import torch - - a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() - b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() - c = kernel(a, b) - - ref_c = a @ b - ref_c = ref_c.float() - torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) - - # Get CUDA Source - print(kernel.get_kernel_source()) + 
+def simplify_and_compare(before, expected, config=None): + """Helper function to run simplify pass and compare results.""" + if config is None: + config = {} + + full_config = {PassConfigKey.TL_SIMPLIFY.value: config} + + with tvm.transform.PassContext(config=full_config): + after = tl.transform.Simplify()(before) + + # Compare bodies only, ignoring function name differences + # Use map_free_vars=True to allow mapping of free variables (function parameters) + after_func = after["main"] + expected_func = expected["main"] + tvm.ir.assert_structural_equal(after_func.body, expected_func.body, map_free_vars=True) + + +def test_stmt_simplify(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + n = te.size_var("n") + with ib.for_range(0, n, name="i") as i, ib.if_scope(i < 12): + A[i] = C[i] + + body = tvm.tir.LetStmt(n, 10, ib.get()) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) + body = tl.transform.Simplify()(mod)["main"].body + assert isinstance(body.body, tvm.tir.BufferStore) + + +def test_thread_extent_simplify(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + n = te.size_var("n") + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + ib.scope_attr(tx, "thread_extent", n) + ib.scope_attr(tx, "thread_extent", n) + ib.scope_attr(ty, "thread_extent", 1) + with ib.if_scope(tx + ty < 12): + A[tx] = C[tx + ty] + body = tvm.tir.LetStmt(n, 10, ib.get()) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) + body = tl.transform.Simplify()(mod)["main"].body + assert isinstance(body.body.body.body, tvm.tir.BufferStore) + + +def test_if_likely(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + n = te.size_var("n") + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + ib.scope_attr(tx, "thread_extent", 32) + 
ib.scope_attr(ty, "thread_extent", 32) + with ib.if_scope(ib.likely(tx * 32 + ty < n)), ib.if_scope(ib.likely(tx * 32 + ty < n)): + A[tx] = C[tx * 32 + ty] + body = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) + body = tl.transform.Simplify()(mod)["main"].body + assert isinstance(body.body.body, tvm.tir.IfThenElse) + assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse) + + +def test_load_store_noop(): + """Store of a value that was just read from the same location is a no-op.""" + + @T.prim_func + def before(A: T.Buffer((1,), "float32")): + A[0] = A[0] + + @T.prim_func + def expected(A: T.Buffer((1,), "float32")): + T.evaluate(0) + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_load_store_noop_after_simplify(): + """As test_load_store_noop, but requiring simplification to identify.""" + + @T.prim_func + def before(A: T.Buffer((1,), "float32")): + A[0] = A[0] + (5.0 - 5.0) + + @T.prim_func + def expected(A: T.Buffer((1,), "float32")): + T.evaluate(0) + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_nested_condition(): + """Nested IfThenElse with the same condition can be simplified.""" + + @T.prim_func + def before(A: T.Buffer((16,), "float32")): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer((16,), "float32")): + for i in T.serial(16): + if i == 5: + A[i] = 0.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_nested_provable_condition(): + """Simplify inner conditional using constraint from outer.""" + + @T.prim_func + def before(A: T.Buffer((16,), "float32")): + for i in T.serial(16): + if i == 5: + if i < 7: + A[i] = 0.0 + + 
@T.prim_func + def expected(A: T.Buffer((16,), "float32")): + for i in T.serial(16): + if i == 5: + A[i] = 0.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_nested_var_condition(): + """Simplify inner conditional using constraint from outer.""" + + @T.prim_func + def before(A: T.Buffer((16,), "float32"), n: T.int32): + for i in T.serial(16): + if i == n: + if i == n: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), n: T.int32): + for i in T.serial(16): + if i == n: + A[i] = 0.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_altered_buffer_contents(): + """No simplification of data-dependent conditionals.""" + + @T.prim_func + def before(A: T.Buffer((1,), "int32"), n: T.int32): + if A[0] == n: + A[0] = A[0] + 1 + if A[0] == n: + A[0] = 0 + + mod_before = tvm.IRModule({"main": before}) + # Expected is the same as before + simplify_and_compare(mod_before, mod_before) + + +def test_negation_of_condition(): + """Use negation of outer condition to simplify inner.""" + + @T.prim_func + def before(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + if i == 5: + if i != 5: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + if i == 5: + A[i] = 1 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_negation_of_not_equal(): + """Test negation with != outer condition.""" + + @T.prim_func + def before(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + if i != 5: + if i == 5: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + if i != 5: + A[i] = 1 + + mod_before = 
tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_negation_of_var_condition(): + """Test negation with dynamic condition.""" + + @T.prim_func + def before(A: T.Buffer((16,), "int32"), n: T.int32): + for i in T.serial(16): + if i == n: + if i != n: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer((16,), "int32"), n: T.int32): + for i in T.serial(16): + if i == n: + A[i] = 1 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_literal_constraint_split_boolean_and(): + """Split a boolean AND into independent constraints.""" + + @T.prim_func + def before(A: T.Buffer((16, 16), "int32"), n: T.int32): + for i, j in T.grid(16, 16): + if i == n and j == n: + if i == n: + A[i, j] = 0 + + @T.prim_func + def expected(A: T.Buffer((16, 16), "int32"), n: T.int32): + for i, j in T.grid(16, 16): + if i == n and j == n: + A[i, j] = 0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_literal_constraint_split_boolean_or(): + """Split a boolean OR into independent constraints.""" + + @T.prim_func + def before(A: T.Buffer((16, 16), "int32"), n: T.int32): + for i, j in T.grid(16, 16): + if i == n or j == n: + A[i, j] = 0 + else: + if i == n: + A[i, j] = 1 + else: + A[i, j] = 2 + + @T.prim_func + def expected(A: T.Buffer((16, 16), "int32"), n: T.int32): + for i, j in T.grid(16, 16): + if i == n or j == n: + A[i, j] = 0 + else: + A[i, j] = 2 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_if_then_else_expr(): + @T.prim_func + def before(A: T.Buffer(16, "float32")): + for i in T.serial(16): + if i < 12: + A[i] = T.if_then_else(i < 12, 1.0, 2.0, 
dtype="float32") + + @T.prim_func + def expected(A: T.Buffer(16, "float32")): + for i in T.serial(16): + if i < 12: + A[i] = 1.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_ceil_log2_int(): + """Simplify expressions resulting from topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer(1, "int32")): + A[0] = T.cast(T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32") + + @T.prim_func + def expected(A: T.Buffer(1, "int32")): + A[0] = 4 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_left_shift_lower_bound(): + """Integer bounds are propagated through left shift.""" + + @T.prim_func + def before(A: T.Buffer(16, "float32")): + for i in T.serial(16): + if T.shift_left(1, i, dtype="int32") >= 1: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer(16, "float32")): + for i in T.serial(16): + A[i] = 0.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_left_shift_upper_bound(): + """Integer bounds are propagated through left shift.""" + + @T.prim_func + def before(A: T.Buffer(16, "float32")): + for i in T.serial(16): + if T.shift_left(31, i, dtype="int32") <= 1015808: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer(16, "float32")): + for i in T.serial(16): + A[i] = 0.0 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_conditional_floor_mod(): + """A regression test for negative floormod denominator.""" + + @T.prim_func + def before(A: T.Buffer(1, "bool"), i: T.int32): + if T.floormod(0 - i, 2) == 0: + A[0] = T.floormod(i, 2) == 0 + + @T.prim_func + def expected(A: 
T.Buffer(1, "bool"), i: T.int32): + if T.floormod(i, -2) == 0: + A[0] = True + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_simplify_rhs_of_boolean_and_using_lhs(): + """Boolean expressions can introduce contexts.""" + + @T.prim_func + def before(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 5 and n < 10 + + @T.prim_func + def expected(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 5 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_APPLY_CONSTRAINTS_TO_BOOLEAN_BRANCHES.value: True}) + + +def test_simplify_lhs_of_boolean_and_using_rhs(): + """Boolean expressions can introduce contexts for their arguments.""" + + @T.prim_func + def before(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 10 and n < 5 + + @T.prim_func + def expected(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 5 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_APPLY_CONSTRAINTS_TO_BOOLEAN_BRANCHES.value: True}) + + +def test_simplify_rhs_of_boolean_or_using_lhs(): + """Boolean expressions can introduce contexts.""" + + @T.prim_func + def before(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 10 or n < 5 + + @T.prim_func + def expected(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 10 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_APPLY_CONSTRAINTS_TO_BOOLEAN_BRANCHES.value: True}) + + +def test_simplify_lhs_of_boolean_or_using_rhs(): + """Boolean expressions can introduce contexts for their arguments.""" + + @T.prim_func + def before(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 5 or n < 10 + + 
@T.prim_func + def expected(A: T.Buffer(1, "bool"), n: T.int32): + A[0] = n < 10 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_APPLY_CONSTRAINTS_TO_BOOLEAN_BRANCHES.value: True}) + + +def test_simplify_conditional_using_buffer_value(): + """Simplify a conditional using the known value in the buffer.""" + + @T.prim_func + def before(A: T.Buffer(1, "int32")): + A[0] = 0 + if A[0] == 0: + A[0] = 42 + + @T.prim_func + def expected(A: T.Buffer(1, "int32")): + A[0] = 0 + A[0] = 42 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_PROVE_CONDITIONAL.value: True}) + + +def test_simplify_non_conditional(): + """Propagate a known value to later expressions.""" + + @T.prim_func + def before(A: T.Buffer(1, "int32")): + A[0] = 0 + A[0] = A[0] + 1 + + @T.prim_func + def expected(A: T.Buffer(1, "int32")): + A[0] = 0 + A[0] = 1 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_SIMPLIFY_EXPRESSIONS.value: True}) + + +def test_suppress_simplify_non_conditional(): + """Propagate a known value to later expressions - disabled.""" + + @T.prim_func + def before(A: T.Buffer(1, "int32")): + A[0] = 0 + A[0] = A[0] + 1 + + mod_before = tvm.IRModule({"main": before}) + simplify_and_compare(mod_before, mod_before, {PassConfigKey.TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_SIMPLIFY_EXPRESSIONS.value: False}) + + +def test_simplify_buffer_store(): + """Simplification using prior known.""" + + @T.prim_func + def before(A: T.Buffer(1, "int32")): + A[0] = 5 + A[0] = A[0] + 7 + + @T.prim_func + def expected(A: T.Buffer(1, "int32")): + A[0] = 5 + A[0] = 12 + + mod_before = tvm.IRModule({"main": before}) + 
mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_SIMPLIFY_EXPRESSIONS.value: True}) + + +def test_rewrite_as_and_of_ors(): + """If enabled, rewrite boolean expressions into AND of OR.""" + + @T.prim_func + def before(A: T.Buffer(3, "bool")): + T.evaluate(A[0] or (A[1] and A[2])) + + @T.prim_func + def expected(A: T.Buffer(3, "bool")): + T.evaluate((A[0] or A[1]) and (A[0] or A[2])) + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_CONVERT_BOOLEAN_TO_AND_OF_ORS.value: True}) + + +def test_suppress_rewrite_as_and_of_ors(): + """Only rewrite into AND of OR when allowed.""" + + @T.prim_func + def before(A: T.Buffer(3, "bool")): + T.evaluate(A[0] or (A[1] and A[2])) + + mod_before = tvm.IRModule({"main": before}) + simplify_and_compare(mod_before, mod_before, {PassConfigKey.TL_SIMPLIFY_CONVERT_BOOLEAN_TO_AND_OF_ORS.value: False}) + + +def test_buffer_shape_constraint(): + @T.prim_func + def before(a: T.handle): + n = T.int64() + A = T.match_buffer(a, (n * 32,), "float32") + A[T.min(T.int64(0), n)] = T.float32(0) + + @T.prim_func + def expected(a: T.handle): + n = T.int64() + A = T.match_buffer(a, (n * 32,), "float32") + A[T.int64(0)] = T.float32(0) + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + simplify_and_compare(mod_before, mod_expected) + + +def test_tilelang_enable_simplify_let_inline_true(): + """Test that let statements are inlined when tilelang_enable_simplify_let_inline=True (default).""" + + @T.prim_func + def before(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + x = i + 1 + A[i] = x + + @T.prim_func + def expected(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + A[i] = i + 1 + + mod_before = tvm.IRModule({"main": before}) + mod_expected = tvm.IRModule({"main": expected}) + # 
Default behavior: let statements are inlined + simplify_and_compare(mod_before, mod_expected, {PassConfigKey.TL_SIMPLIFY_ENABLE_LET_INLINE.value: True}) + + +def test_tilelang_enable_simplify_let_inline_false(): + """Test that let statements are NOT inlined when tilelang_enable_simplify_let_inline=False.""" + + @T.prim_func + def before(A: T.Buffer((16,), "int32")): + for i in T.serial(16): + x = i + 1 + A[i] = x + + mod_before = tvm.IRModule({"main": before}) + # When disabled, let statements should be preserved (before == after) + simplify_and_compare(mod_before, mod_before, {PassConfigKey.TL_SIMPLIFY_ENABLE_LET_INLINE.value: False}) if __name__ == "__main__": diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 502635a7b..5a1fd61de 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -8,7 +8,41 @@ class PassConfigKey(str, Enum): # TileLang specific configs TL_SIMPLIFY = "tl.Simplify" - """Enable/disable TileLang simplification passes. Default: True""" + """Configuration for TileLang simplification passes. + + This is a dict-based config with the following options: + - transitively_prove_inequalities: bool, default False + - convert_boolean_to_and_of_ors: bool, default False + - apply_constraints_to_boolean_branches: bool, default False + - propagate_knowns_to_prove_conditional: bool, default False + - propagate_knowns_to_simplify_expressions: bool, default False + - enable_simplify_let_inline: bool, default True + + Usage: + with tvm.transform.PassContext(config={ + "tl.Simplify": {"enable_simplify_let_inline": False} + }): + mod = tl.transform.Simplify()(mod) + """ + + # TL_SIMPLIFY sub-config keys + TL_SIMPLIFY_TRANSITIVELY_PROVE_INEQUALITIES = "transitively_prove_inequalities" + """Enable transitive inequality proving in simplification. Default: False""" + + TL_SIMPLIFY_CONVERT_BOOLEAN_TO_AND_OF_ORS = "convert_boolean_to_and_of_ors" + """Convert boolean expressions to AND of ORs form. 
Default: False""" + + TL_SIMPLIFY_APPLY_CONSTRAINTS_TO_BOOLEAN_BRANCHES = "apply_constraints_to_boolean_branches" + """Apply constraints to simplify boolean branches. Default: False""" + + TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_PROVE_CONDITIONAL = "propagate_knowns_to_prove_conditional" + """Propagate known values to prove conditionals. Default: False""" + + TL_SIMPLIFY_PROPAGATE_KNOWNS_TO_SIMPLIFY_EXPRESSIONS = "propagate_knowns_to_simplify_expressions" + """Propagate known values to simplify expressions. Default: False""" + + TL_SIMPLIFY_ENABLE_LET_INLINE = "enable_simplify_let_inline" + """Enable inlining of let statements during simplification. Default: True""" TL_DISABLE_DATA_RACE_CHECK = "tl.disable_data_race_check" """Disable data race check in TileLang. Default: False"""