Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check generated C code from apps #230

Merged
merged 9 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import re
import types
from pathlib import Path
from typing import Optional, Union, List

from .API_types import ProcedureBase
Expand Down Expand Up @@ -131,10 +132,16 @@ def do_s(self, s):
# Procedure Objects


def compile_procs(proc_list, path, c_file, h_file):
def compile_procs(proc_list, basedir: Path, c_file: str, h_file: str):
c_data, h_data = compile_procs_to_strings(proc_list, h_file)
(basedir / c_file).write_text(c_data)
(basedir / h_file).write_text(h_data)


def compile_procs_to_strings(proc_list, h_file_name: str):
assert isinstance(proc_list, list)
assert all(isinstance(p, Procedure) for p in proc_list)
run_compile([p._loopir_proc for p in proc_list], path, c_file, h_file)
return run_compile([p._loopir_proc for p in proc_list], h_file_name)


class Procedure(ProcedureBase):
Expand Down Expand Up @@ -315,9 +322,8 @@ def c_code_str(self):
decls, defns = compile_to_strings("c_code_str", [self._loopir_proc])
return decls + '\n' + defns

def compile_c(self, directory, filename):
run_compile([self._loopir_proc], directory,
(filename + ".c"), (filename + ".h"))
def compile_c(self, directory: Path, filename: str):
compile_procs([self._loopir_proc], directory, f'{filename}.c', f'{filename}.h')

def interpret(self, **kwargs):
run_interpreter(self._loopir_proc, kwargs)
Expand Down
189 changes: 103 additions & 86 deletions src/exo/LoopIR_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from collections import ChainMap
from collections import defaultdict
from pathlib import Path

from .LoopIR import LoopIR, LoopIR_Do
from .LoopIR import T
Expand Down Expand Up @@ -219,17 +220,14 @@ def window_struct(basetyp, n_dims):
# top level compiler function called by tests!


def run_compile(proc_list, path, c_file, h_file, header_guard=None):
file_stem = re.match(r'^([^\.]+)\.[^\.]+$', c_file)
if not file_stem:
raise ValueError("Expected file name to end "
"with extension: e.g. ___.__ ")
lib_name = sanitize_str(file_stem[1])
def run_compile(proc_list, h_file_name: str):
file_stem = str(Path(h_file_name).stem)
lib_name = sanitize_str(file_stem)
fwd_decls, body = compile_to_strings(lib_name, proc_list)

if header_guard is None:
header_guard = re.sub(r'\W', '_', h_file).upper()
body = f'#include "{h_file_name}"\n\n{body}'

header_guard = f'{lib_name}_H'.upper()
fwd_decls = f'''
#pragma once
#ifndef {header_guard}
Expand All @@ -247,103 +245,68 @@ def run_compile(proc_list, path, c_file, h_file, header_guard=None):
#endif // {header_guard}
'''

body = f'#include "{h_file}"\n\n{body}'

with open(os.path.join(path, h_file), "w") as f_header:
f_header.write(fwd_decls)

with open(os.path.join(path, c_file), "w") as f_cpp:
f_cpp.write(body)
return body, fwd_decls


def compile_to_strings(lib_name, proc_list):
# get transitive closure of call-graph
# Get transitive closure of call-graph
orig_procs = [id(p) for p in proc_list]
proc_list = find_all_subprocs(proc_list)
mem_list = find_all_mems(proc_list)
builtin_list = find_all_builtins(proc_list)
config_list = find_all_configs(proc_list)

# check for name conflicts between procs
used_names = set()
for p in proc_list:
if p.name in used_names:
raise Exception(f"Cannot compile multiple "
f"procedures named '{p.name}'")
used_names.add(p.name)

body = [
"static int _floor_div(int num, int quot) {",
" int off = (num>=0)? 0 : quot-1;",
" return (num-off)/quot;",
"}",
"",
"static int8_t _clamp_32to8(int32_t x) {",
" return (x < -128)? -128 : ((x > 127)? 127 : x);",
"}",
"",
]
# Determine name for library context struct
ctxt_name = f"{lib_name}_Context"

fwd_decls = []
struct_defns = set()
def from_lines(x):
return '\n'.join(x)

m: Memory
for m in mem_list:
body.append(f'{m.global_()}\n')
proc_list = list(sorted(find_all_subprocs(proc_list), key=lambda x: x.name))

for b in builtin_list:
glb = b.globl()
if glb:
body.append(glb)
body.append("\n")
# Header contents
ctxt_def = _compile_context_struct(find_all_configs(proc_list), ctxt_name)
struct_defns = set()
fwd_decls = []

# Build Context Struct
ctxt_name = f"{lib_name}_Context"
ctxt_def = [f"typedef struct {ctxt_name} {{ ",
f""]
for c in config_list:
if c.is_allow_rw():
sdef_lines = c.c_struct_def()
sdef_lines = [f" {line}" for line in sdef_lines]
ctxt_def += sdef_lines
ctxt_def += [""]
else:
ctxt_def += [f"// config '{c.name()}' not materialized",
""]
ctxt_def += [f"}} {ctxt_name};"]
fwd_decls += ctxt_def
fwd_decls.append("\n")
# check that we don't have a name conflict on configs
config_names = {c.name() for c in config_list}
if len(config_names) != len(config_list):
raise TypeError("Cannot compile while using two configs "
"with the same name")
# Body contents
memory_code = _compile_memories(find_all_mems(proc_list))
builtin_code = _compile_builtins(find_all_builtins(proc_list))
priv_decls = []
proc_bodies = []

# Compile proc bodies
seen_procs = set()
for p in proc_list:
# don't compile instruction procedures, but add a comment?
if p.name in seen_procs:
raise TypeError(f"multiple procs named {p.name}")
seen_procs.add(p.name)

# don't compile instruction procedures, but add a comment.
if p.instr is not None:
argstr = ','.join([str(a.name) for a in p.args])
body.append("\n/* relying on the following instruction...\n"
f"{p.name}({argstr})\n"
f'{p.instr}\n'
"*/\n")
proc_bodies.extend([
'',
'/* relying on the following instruction..."',
f"{p.name}({argstr})",
p.instr,
"*/",
])
else:
p_to_start = p
orig_p = id(p)
p = PrecisionAnalysis(p).result()
p = WindowAnalysis(p).result()
p = MemoryAnalysis(p).result()
comp = Compiler(p, ctxt_name)
d, b = comp.comp_top()
struct_defns = struct_defns.union(comp.struct_defns())
struct_defns |= comp.struct_defns()
# only dump .h-file forward declarations for requested procedures
if id(p_to_start) in orig_procs:
if orig_p in orig_procs:
fwd_decls.append(d)
body.append(b)
else:
priv_decls.append(d)
proc_bodies.append(b)

# add struct definitions before the other forward declarations
fwd_decls = list(struct_defns) + fwd_decls
fwd_decls = '\n'.join(fwd_decls)
fwd_decls = f'''
# Structs are just blobs of code... still sort them for output stability
struct_defns = list(sorted(struct_defns))

header_contents = f'''
#include <stdint.h>
#include <stdbool.h>

Expand All @@ -365,12 +328,66 @@ def compile_to_strings(lib_name, proc_list):
# define EXO_ASSUME(expr) ((void)(expr))
#endif

{fwd_decls}
{from_lines(ctxt_def)}
{from_lines(struct_defns)}
{from_lines(fwd_decls)}
'''

body = '\n'.join(body)
body_contents = f'''
static int _floor_div(int num, int quot) {{
int off = (num>=0)? 0 : quot-1;
return (num-off)/quot;
}}

static int8_t _clamp_32to8(int32_t x) {{
return (x < -128)? -128 : ((x > 127)? 127 : x);
}}

{from_lines(memory_code)}
{from_lines(builtin_code)}
{from_lines(priv_decls)}
{from_lines(proc_bodies)}
'''

return header_contents, body_contents


def _compile_builtins(builtins):
builtin_code = []
for b in sorted(builtins, key=lambda x: x.name()):
if glb := b.globl():
builtin_code.append(glb)
return builtin_code


def _compile_memories(mems):
memory_code = []
for m in sorted(mems, key=lambda x: x.name()):
memory_code.append(m.global_())
return memory_code


def _compile_context_struct(configs, ctxt_name):
ctxt_def = [f"typedef struct {ctxt_name} {{ ",
f""]
seen = set()
for c in sorted(configs, key=lambda x: x.name()):
name = c.name()

return fwd_decls, body
if name in seen:
raise TypeError(f"multiple configs named {name}")
seen.add(name)

if c.is_allow_rw():
sdef_lines = c.c_struct_def()
sdef_lines = [f" {line}" for line in sdef_lines]
ctxt_def += sdef_lines
ctxt_def += [""]
else:
ctxt_def += [f"// config '{name}' not materialized",
""]
ctxt_def += [f"}} {ctxt_name};"]
return ctxt_def


# --------------------------------------------------------------------------- #
Expand Down
36 changes: 19 additions & 17 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2848,35 +2848,37 @@ def has_div_mod_config(self, e):
def index_start(self, e):
assert isinstance(e, LoopIR.expr)
# Div and mod need more subtle handling. Don't normalize for now.
# Skip ReadConfigs, they needs careful handling because they're not Sym.
# Skip ReadConfigs, they need careful handling because they're not Sym.
if self.has_div_mod_config(e):
return e

# Make a map of symbols and coefficients
n_map = self.normalize_e(e)

# Write back to LoopIR.expr
def get_loopir(key, value):
vconst = LoopIR.Const(value, T.int, e.srcinfo)
if key == self.C:
return vconst
else:
readkey = LoopIR.Read(key, [], e.type, e.srcinfo)
return LoopIR.BinOp('*', vconst, readkey, e.type, e.srcinfo)
def scale_read(coeff, key):
return LoopIR.BinOp(
'*',
LoopIR.Const(coeff, T.int, e.srcinfo),
LoopIR.Read(key, [], e.type, e.srcinfo),
e.type,
e.srcinfo
)

delete_zero = { key: n_map[key] for key in n_map if n_map[key] != 0 }
new_e = LoopIR.Const(0, T.int, e.srcinfo)
for key, val in delete_zero.items():
if val > 0:
# add
new_e = LoopIR.BinOp('+', new_e, get_loopir(key, val), e.type, e.srcinfo)
new_e = LoopIR.Const(n_map.get(self.C, 0), T.int, e.srcinfo)

delete_zero = [(n_map[v], v)
for v in n_map
if v != self.C and n_map[v] != 0]

for coeff, v in sorted(delete_zero):
if coeff > 0:
new_e = LoopIR.BinOp('+', new_e, scale_read(coeff, v), e.type, e.srcinfo)
else:
# sub
new_e = LoopIR.BinOp('-', new_e, get_loopir(key, -val), e.type, e.srcinfo)
new_e = LoopIR.BinOp('-', new_e, scale_read(-coeff, v), e.type, e.srcinfo)

return new_e


def map_e(self, e):
if e.type.is_indexable():
return self.index_start(e)
Expand Down
3 changes: 2 additions & 1 deletion src/exo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .API import Procedure, compile_procs, proc, instr, config
from .API import Procedure, compile_procs, compile_procs_to_strings, proc, instr, config
from .LoopIR_scheduling import SchedulingError
from .parse_fragment import ParseFragmentError
from .configs import Config
Expand All @@ -12,6 +12,7 @@
__all__ = [
"Procedure",
"compile_procs",
"compile_procs_to_strings",
"proc",
"instr",
"config",
Expand Down
7 changes: 7 additions & 0 deletions src/exo/pattern_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ def match_e(self, pat, e):
if isinstance(pat, PAST.E_Hole):
return True

# Special case: -3 can be parsed as USub(Const(3))... it should match Const(-3)
if (isinstance(pat, PAST.USub)
and isinstance(pat.arg, PAST.Const)
and isinstance(e, LoopIR.Const)):
pat = pat.arg.update(val=-pat.arg.val)

# first ensure that the pattern and statement
# are the same constructor
if not isinstance(e, (LoopIR.WindowExpr,) + tuple(_PAST_to_LoopIR[type(pat)])):
Expand All @@ -290,6 +296,7 @@ def match_e(self, pat, e):
elif isinstance(e, LoopIR.Const):
return pat.val == e.val
elif isinstance(e, LoopIR.BinOp):
# TODO: do we need to handle associativity? (a + b) + c vs a + (b + c)?
return ( pat.op == e.op and
self.match_e(pat.lhs, e.lhs) and
self.match_e(pat.rhs, e.rhs) )
Expand Down
Loading