Skip to content

Commit

Permalink
cond.not.where(T, F) -> cond.where(F, T)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Dec 1, 2024
1 parent 509c4a5 commit a93bed7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 7 additions & 1 deletion test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def test_div_neg_all_range(self):
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")

# NOTE: tests are not correct in symbolic
def test_div_neg_then_neg(self):
# taken from arange opts
lidx0 = Variable("lidx0", 0, 7)
Expand Down Expand Up @@ -492,6 +491,13 @@ def test_where_removal(self):
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")

def test_where_flip(self):
cond = Variable("a", 0, 3).lt(2).logical_not()
t, f = Variable("t", 0, 3), Variable("f", 0, 3)
self.helper_test_variable(cond, 0, 1, "((a<2)!=True)")
self.helper_test_variable(cond.where(t, f), 0, 3, "(f if (a<2) else t)")
self.helper_test_variable(cond.where(f, t), 0, 3, "(t if (a<2) else f)")

class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10
Expand Down
8 changes: 6 additions & 2 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,9 +1127,13 @@ def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x,
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
# a conditional with the same results either way is a noop, also fold const conditionals
# ** where **
# a conditional with the same results either way is a noop
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# fold const conditionals
(UPat.cvar("cond", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda cond,c0,c1: c0 if cond.arg else c1),
# flip if cond has a not
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("c0"), UPat.var("c1")), lambda cond,c0,c1: cond.where(c1, c0)),
# ALU min==max -> CONST (slow!)
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
Expand Down

0 comments on commit a93bed7

Please sign in to comment.