diff --git a/toonygrad/engine/realize.py b/toonygrad/engine/realize.py index 6f03147..1fb7f1f 100644 --- a/toonygrad/engine/realize.py +++ b/toonygrad/engine/realize.py @@ -8,30 +8,6 @@ from toonygrad.codegen.linearize import linearize_uop from toonygrad.codegen.uopgraph import full_graph_rewrite from toonygrad.renderer import Renderer - -acc_number = 0 -def do_reduce(root:UOp): - global acc_number - reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents) - ret = root.src[0] - if len(reduce_parented): - acc = UOp(UOps.DEFINE_ACC, root.dtype, - (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,)) - acc_number += 1 - ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret))) - # for MAX, we can just ignore the unparented - if root.arg is BinaryOps.ADD: - for r in reduce_unparented:ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) - return ret - -no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR), - name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)]) - -just_reduce = PatternMatcher([ - # do reduce - (UPat(UOps.REDUCE, name="root"), do_reduce), -]) - from toonygrad.codegen.kernel import Kernel @track_rewrites diff --git a/toonygrad/engine/schedule.py b/toonygrad/engine/schedule.py index e878461..ed328a6 100644 --- a/toonygrad/engine/schedule.py +++ b/toonygrad/engine/schedule.py @@ -1,6 +1,6 @@ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Optional from dataclasses import dataclass -from toonygrad.ops import UOp, graph_rewrite, PatternMatcher, UPat, UOps, symbolic, track_rewrites +from toonygrad.ops import UOp, graph_rewrite, PatternMatcher, UPat, UOps, symbolic, track_rewrites, resolve from toonygrad.engine.lazy import LazyBuffer from toonygrad.shape.symbolic import Variable from toonygrad.shape.shapetracker import ShapeTracker @@ -44,11 +44,20 @@ def append_kernel(k:List[UOp], base:UOp): k.append(base.sink()) (UPat(UOps.LOAD, src=(UPat(), UPat(), UPat()), name="ld"), lambda k,ld: UOp.load(ld.src[0], ld.src[1], dtype=ld.dtype)), ]) -def append_buffer(b:List[Buffer], base:UOp): - if base.buffer not in b: b.append(base.buffer) +def index_buffer(b:List[Buffer], buf:UOp, view:UOp, store:Optional[UOp]=None) -> UOp: + if buf.buffer not in b: b.append(buf.buffer) # should this be the ptr, or the buffer? - return UOp(UOps.DEFINE_GLOBAL, base.dtype.ptr(), (), b.index(base.buffer)) -enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="base"), append_buffer)]) + idx, mask = view.st.to_indexed_uops() + ret = UOp(UOps.DEFINE_GLOBAL, buf.dtype.ptr(), (), b.index(buf.buffer)) + if resolve(mask != True): ret = mask.where(ret, ret.const_like(0)) + ret = UOp(UOps.INDEX, buf.dtype.ptr(), (ret, idx)) + return UOp.store(ret, store) if store is not None else UOp.load(ret, dtype=buf.dtype) + +enumerate_bufs = PatternMatcher([ + #(UPat(UOps.BUFFER, name="base"), append_buffer), + (UPat(UOps.LOAD, src=(UPat(UOps.BUFFER, name="buf"), UPat(UOps.VIEW, name="view"))), index_buffer), + (UPat(UOps.STORE, src=(UPat(UOps.BUFFER, name="buf"), UPat(UOps.VIEW, name="view"), UPat.var("store"))), index_buffer), +]) @track_rewrites def _schedule_rewrite(sink:UOp) -> List[ScheduleItem]: diff --git a/toonygrad/ops.py b/toonygrad/ops.py index 30e077e..2dbbad8 100644 --- a/toonygrad/ops.py +++ b/toonygrad/ops.py @@ -106,6 +106,7 @@ class UOps(FastEnum): EMPTY = auto() BUFFER_VIEW = auto() + INDEX = auto() EXPAND = auto() CONTRACT = auto() VIEW = auto()