Skip to content

Commit

Permalink
refactor interpreter; fix C division bug; fix some of the failing tes…
Browse files Browse the repository at this point in the history
…ts; add 'built-in' test
  • Loading branch information
andrewdalex committed Nov 6, 2024
1 parent 27ae79e commit dd8784c
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 154 deletions.
140 changes: 76 additions & 64 deletions src/exo/LoopIR_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,64 +35,16 @@ def run_interpreter(proc, kwargs):
Interpreter(proc, kwargs)


# context is global
ctxt = defaultdict(dict)

class Interpreter:
def __init__(self, proc, kwargs, use_randomization=False):
assert isinstance(proc, LoopIR.proc)

proc = ParallelAnalysis().run(proc)
proc = PrecisionAnalysis().run(proc) # TODO: need this?
proc = WindowAnalysis().apply_proc(proc)
proc = MemoryAnalysis().run(proc) # TODO: need this?
if not isinstance(proc, LoopIR.proc):
raise TypeError(f"Expected {proc.name} to be of type proc")

self.proc = proc
self.env = ChainMap()
self.use_randomization = use_randomization
self.ctxt = defaultdict(dict)

# type check args
for a in proc.args:
if not str(a.name) in kwargs:
raise TypeError(f"expected argument '{a.name}' to be supplied")

if a.type is T.size:
if not is_pos_int(kwargs[str(a.name)]):
raise TypeError(
f"expected size '{a.name}' to have positive integer value"
)
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.index:
if type(kwargs[str(a.name)]) is not int:
raise TypeError(
f"expected index variable '{a.name}' to be an integer"
)
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.bool:
if type(kwargs[str(a.name)]) is not bool:
raise TypeError(f"expected bool variable '{a.name}' to be a bool")
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.stride:
if type(kwargs[str(a.name)]) is not int:
raise TypeError(
f"expected stride variable '{a.name}' to be an integer"
)
self.env[a.name] = kwargs[str(a.name)]
else:
self.typecheck_input_buffer(a, kwargs)
self.env[a.name] = kwargs[str(a.name)]

# evaluate preconditions
for pred in proc.preds:
if isinstance(pred, LoopIR.Const):
continue
else:
assert self.eval_e(pred), "precondition not satisfied"

# eval statements
self.env = self.env.new_child()
self.eval_stmts(proc.body)
self.env = self.env.parents
self.eval_proc(proc, kwargs)

def _new_scope(self):
self.env = self.env.new_child()
Expand Down Expand Up @@ -154,14 +106,60 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
f"but got shape {tuple(buf.shape)}"
)

def eval_proc(self, proc, kwargs):
proc = ParallelAnalysis().run(proc)
proc = PrecisionAnalysis().run(proc) # TODO: need this?
proc = WindowAnalysis().apply_proc(proc)
proc = MemoryAnalysis().run(proc) # TODO: need this?

for a in proc.args:
if not str(a.name) in kwargs:
raise TypeError(f"expected argument '{a.name}' to be supplied")

if a.type is T.size:
if not is_pos_int(kwargs[str(a.name)]):
raise TypeError(
f"expected size '{a.name}' to have positive integer value"
)
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.index:
if type(kwargs[str(a.name)]) is not int:
raise TypeError(
f"expected index variable '{a.name}' to be an integer"
)
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.bool:
if type(kwargs[str(a.name)]) is not bool:
raise TypeError(f"expected bool variable '{a.name}' to be a bool")
self.env[a.name] = kwargs[str(a.name)]
elif a.type is T.stride:
if type(kwargs[str(a.name)]) is not int:
raise TypeError(
f"expected stride variable '{a.name}' to be an integer"
)
self.env[a.name] = kwargs[str(a.name)]
else:
self.typecheck_input_buffer(a, kwargs)
self.env[a.name] = kwargs[str(a.name)]

# evaluate preconditions
for pred in proc.preds:
if isinstance(pred, LoopIR.Const):
continue
else:
assert self.eval_e(pred), "precondition not satisfied"

# eval statements
self.eval_stmts(proc.body)

def eval_stmts(self, stmts):
for s in stmts:
self.eval_s(s)

def eval_s(self, s):
if isinstance(s, LoopIR.Pass):
pass

elif isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
lbuf = self.env[s.name]
if len(s.idx) == 0:
Expand All @@ -179,12 +177,14 @@ def eval_s(self, s):
elif isinstance(s, LoopIR.WriteConfig):
nm = s.config.name()
rhs = self.eval_e(s.rhs)
ctxt[nm][s.field] = rhs
self.ctxt[nm][s.field] = rhs

elif isinstance(s, LoopIR.WindowStmt):
# nm = rbuf[...]
assert s.name not in self.env, "WindowStmt should be a fresh assignment"
assert isinstance(s.rhs, LoopIR.WindowExpr), "WindowStmt rhs should be WindowExpr"
assert isinstance(
s.rhs, LoopIR.WindowExpr
), "WindowStmt rhs should be WindowExpr"
self.env[s.name] = self.eval_e(s.rhs)

elif isinstance(s, LoopIR.If):
Expand Down Expand Up @@ -225,7 +225,9 @@ def eval_s(self, s):
argvals = [self.eval_e(a, call_arg=True) for a in s.args]
argnames = [str(a.name) for a in s.f.args]
kwargs = {nm: val for nm, val in zip(argnames, argvals)}
Interpreter(s.f, kwargs, use_randomization=self.use_randomization)
self._new_scope()
self.eval_proc(s.f, kwargs)
self._del_scope()

else:
assert False, "bad statement case"
Expand Down Expand Up @@ -253,10 +255,14 @@ def stringify_w_access(a):
assert False, "bad w_access case"

# hack to handle interval indexes: LoopIR.Interval returns a string representing the interval
idx = ("0",) if len(e.idx) == 0 else tuple(stringify_w_access(a) for a in e.idx)
idx = (
("0",)
if len(e.idx) == 0
else tuple(stringify_w_access(a) for a in e.idx)
)
res = eval(f"buf[{','.join(idx)}]")
return res

elif isinstance(e, LoopIR.Const):
return e.val

Expand All @@ -268,9 +274,12 @@ def stringify_w_access(a):
return lhs - rhs
elif e.op == "*":
return lhs * rhs
elif e.op == "/": # is this right?
if isinstance(lhs, int):
return (lhs + rhs - 1) // rhs
elif e.op == "/":
if isinstance(lhs, int) and isinstance(rhs, int):
# this is what was here before and without the rhs check
# counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl
# return (lhs + rhs - 1) // rhs
return int(lhs / rhs)
else:
return lhs / rhs
elif e.op == "%":
Expand All @@ -293,9 +302,12 @@ def stringify_w_access(a):
elif isinstance(e, LoopIR.USub):
return -self.eval_e(e.arg)

# BuiltIns don't go to the interpreter, they are just called (via call) like a proc
# TODO Discuss to make sure
elif isinstance(e, LoopIR.BuiltIn):
args = [self.eval_e(a) for a in e.args]
return e.f.interpret(args)
assert False, "Not implemented"
# args = [self.eval_e(a) for a in e.args]
# return e.f.interpret(args)

elif isinstance(e, LoopIR.StrideExpr):
buf = self.env[e.name]
Expand All @@ -305,7 +317,7 @@ def stringify_w_access(a):

elif isinstance(e, LoopIR.ReadConfig):
nm = e.config.name()
return ctxt[nm][e.field]
return self.ctxt[nm][e.field]

else:
print(e)
Expand Down
Loading

0 comments on commit dd8784c

Please sign in to comment.