Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
// First convert a < b into a - b < 0
PrimExpr expr = this->CanonicalMutate(op->a - op->b);
// Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c)
// 1. if can prove -d < xn < d, then we can simplify
// 1. if can prove 0 <= xn < d, then we can simplify
// the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < c/d,
// e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 2`
// 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d % (m/d)
Expand All @@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
ICHECK(extra->dtype == dtype);
PrimExpr normal_extra = extra->Normalize();
if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
// Case 1. -d < xn < d
this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) {
// Case 1. 0 <= xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
} else if (extra->args.size() == 1 &&
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/pack_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ enum ArgConvertCode {
};

inline ArgConvertCode GetArgConvertCode(DLDataType t) {
ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now";
if (t.code == kDLInt) {
if (t.bits == 64U) return INT64_TO_INT64;
if (t.bits == 32U) return INT64_TO_INT32;
Expand Down
1 change: 0 additions & 1 deletion tests/python/arith/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ def test_simplify_le():
ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))

ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
ck.verify(x * 8 + y - z < 16, x < 2)

n = te.size_var("n")
ck.verify(x * 8 + y < n, x * 8 + y < n)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/dlight/test_gpu_low_batch_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T
with T.block("NT_matmul_intermediate_pad"):
v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
Expand Down Expand Up @@ -240,7 +240,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float
with T.block("NT_matmul_pad"):
v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
T.reads(NT_matmul_pad_local[v0, T.int64(0), v1])
T.writes(NT_matmul[v0, T.int64(0), v1])
NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1]
Expand Down Expand Up @@ -369,7 +369,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"
with T.block("C_pad"):
v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
T.reads(C_pad_local[v0, v1])
T.writes(C[v0, v1])
C[v0, v1] = C_pad_local[v0, v1]
Expand Down Expand Up @@ -516,7 +516,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo
with T.block("C_pad"):
v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1)
T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size)
T.reads(C_pad_local[v0, 0, v1])
T.writes(C[v0, 0, v1])
C[v0, 0, v1] = C_pad_local[v0, 0, v1]
Expand Down
4 changes: 2 additions & 2 deletions tests/python/dlight/test_gpu_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
T.reads(matmul_pad_local[v0, v1, v2])
T.writes(matmul[v0, v1, v2])
matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2]
Expand Down Expand Up @@ -835,7 +835,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
Expand Down
Loading