Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix const-window argument passing #288

Merged
merged 4 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 66 additions & 48 deletions src/exo/LoopIR_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,10 @@ def run_compile(proc_list, h_file_name: str):
lib_name = sanitize_str(file_stem)
fwd_decls, body = compile_to_strings(lib_name, proc_list)

body = f'#include "{h_file_name}"\n\n{body}'
source = f'#include "{h_file_name}"\n\n{body}'

header_guard = f"{lib_name}_H".upper()
fwd_decls = f"""
header = f"""
#pragma once
#ifndef {header_guard}
#define {header_guard}
Expand All @@ -266,7 +266,7 @@ def run_compile(proc_list, h_file_name: str):
#endif // {header_guard}
"""

return body, fwd_decls
return source, header


_static_helpers = {
Expand Down Expand Up @@ -608,19 +608,21 @@ def get_idx_offset(self, name, typ, idx):
acc = " + ".join([f"({i}) * ({s})" for i, s in zip(idx, strides)])
return acc

def get_window_type(self, typ):
def get_window_type(self, typ, is_const=None):
assert isinstance(typ, T.Window) or (
isinstance(typ, LoopIR.fnarg) and typ.type.is_win()
)

if isinstance(typ, T.Window):
base = typ.as_tensor.basetype()
n_dims = len(typ.as_tensor.shape())
is_const = typ.src_buf not in self.non_const
if is_const is None:
is_const = typ.src_buf not in self.non_const
else:
base = typ.type.basetype()
n_dims = len(typ.type.shape())
is_const = typ.name not in self.non_const
if is_const is None:
is_const = typ.name not in self.non_const

win = window_struct(base, n_dims, is_const)
self.window_defns.add(win)
Expand Down Expand Up @@ -729,7 +731,7 @@ def comp_s(self, s):
assert all(
a.type.is_win() == fna.type.is_win() for a, fna in zip(s.args, s.f.args)
)
args = [self.comp_e(e, call_arg=True) for e in s.args]
args = [self.comp_fnarg(e, s.f, i) for i, e in enumerate(s.args)]
if s.f.instr is not None:
d = dict()
assert len(s.f.args) == len(args)
Expand All @@ -753,53 +755,69 @@ def comp_s(self, s):
else:
assert False, "bad case"

def comp_e(self, e, prec=0, call_arg=False):
etyp = type(e)

if etyp is LoopIR.Read:
def comp_fnarg(self, e, fn, i, *, prec=0):
if isinstance(e, LoopIR.Read):
assert not e.idx
rtyp = self.envtyp[e.name]
if call_arg:
assert len(e.idx) == 0
if rtyp.is_indexable():
return self.env[e.name]
elif rtyp is T.bool:
return self.env[e.name]
elif rtyp is T.stride:
return self.env[e.name]
elif e.name in self._scalar_refs:
return self.env[e.name]
elif rtyp.is_tensor_or_window():
return self.env[e.name]
else:
assert rtyp.is_real_scalar()
return f"&{self.env[e.name]}"
if rtyp.is_indexable():
return self.env[e.name]
elif rtyp is T.bool:
return self.env[e.name]
elif rtyp is T.stride:
return self.env[e.name]
elif e.name in self._scalar_refs:
return self.env[e.name]
elif rtyp.is_tensor_or_window():
return self.env[e.name]
else:
assert rtyp.is_real_scalar()
return f"&{self.env[e.name]}"
elif isinstance(e, LoopIR.WindowExpr):
if isinstance(fn, LoopIR.proc):
callee_buf = fn.args[i].name
is_const = callee_buf not in set(
x.buffer for x in fn.eff.writes + fn.eff.reduces
)
else:
if rtyp.is_indexable() or rtyp is T.bool or rtyp == T.stride:
return self.env[e.name]
raise NotImplementedError("Passing windows to built-ins")
win_struct = self.get_window_type(e.type, is_const)
data, strides = self.window_struct_fields(e)
return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}"
else:
return self.comp_e(e, prec)

mem: Memory = self.mems[e.name]
def comp_e(self, e, prec=0):
if isinstance(e, LoopIR.Read):
rtyp = self.envtyp[e.name]
if rtyp.is_indexable() or rtyp is T.bool or rtyp == T.stride:
return self.env[e.name]

if not mem.can_read():
raise MemGenError(
f"{e.srcinfo}: cannot read from buffer "
f"'{e.name}' in memory '{mem.name()}'"
)
mem: Memory = self.mems[e.name]

if e.name in self._scalar_refs:
return f"*{self.env[e.name]}"
elif not rtyp.is_tensor_or_window():
return self.env[e.name]
else:
return self.access_str(e.name, e.idx)
elif etyp is LoopIR.WindowExpr:
if not mem.can_read():
raise MemGenError(
f"{e.srcinfo}: cannot read from buffer "
f"'{e.name}' in memory '{mem.name()}'"
)

if e.name in self._scalar_refs:
return f"*{self.env[e.name]}"
elif not rtyp.is_tensor_or_window():
return self.env[e.name]
else:
return self.access_str(e.name, e.idx)

elif isinstance(e, LoopIR.WindowExpr):
win_struct = self.get_window_type(e.type)
data, strides = self.window_struct_fields(e)
return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}"
elif etyp is LoopIR.Const:

elif isinstance(e, LoopIR.Const):
if isinstance(e.val, bool):
return "true" if e.val else "false"
return str(e.val)
elif etyp is LoopIR.BinOp:

elif isinstance(e, LoopIR.BinOp):
local_prec = op_prec[e.op]
int_div = e.op == "/" and not e.type.is_numeric()
if int_div:
Expand All @@ -824,19 +842,19 @@ def comp_e(self, e, prec=0, call_arg=False):
s = f"({s})"

return s
elif etyp is LoopIR.USub:
elif isinstance(e, LoopIR.USub):
return f'-{self.comp_e(e.arg, op_prec["~"])}'

elif etyp is LoopIR.BuiltIn:
args = [self.comp_e(a, call_arg=True) for a in e.args]
elif isinstance(e, LoopIR.BuiltIn):
args = [self.comp_fnarg(a, e, i) for i, a in enumerate(e.args)]
return e.f.compile(args)

elif etyp is LoopIR.StrideExpr:
elif isinstance(e, LoopIR.StrideExpr):
basetyp = self.envtyp[e.name]
strides = self.get_strides(e.name, basetyp)

return strides[e.dim]
elif etyp is LoopIR.ReadConfig:
elif isinstance(e, LoopIR.ReadConfig):
if not e.config.is_allow_rw():
raise ConfigError(
f"{e.srcinfo}: cannot read from config '{e.config.name()}'"
Expand Down
8 changes: 6 additions & 2 deletions tests/golden/pldi22/test_gemmini_conv_ae/test_conv_ae.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ typedef struct c_code_str_Context {
} ConfigStore;

} c_code_str_Context;
struct exo_win_2i32{
int32_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_2i32c{
const int32_t * const data;
const int_fast32_t strides[2];
Expand All @@ -53,8 +57,8 @@ struct exo_win_2i8c{
const int8_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_3i8c{
const int8_t * const data;
struct exo_win_3i8{
int8_t * const data;
const int_fast32_t strides[3];
};
// conv_on_gemmini(
Expand Down
8 changes: 6 additions & 2 deletions tests/golden/pldi22/test_gemmini_matmul_ae/test_matmul_ae.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ typedef struct c_code_str_Context {
} ConfigStore;

} c_code_str_Context;
struct exo_win_2i32{
int32_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_2i32c{
const int32_t * const data;
const int_fast32_t strides[2];
Expand All @@ -57,8 +61,8 @@ struct exo_win_2i8c{
const int8_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_3i8c{
const int8_t * const data;
struct exo_win_3i8{
int8_t * const data;
const int_fast32_t strides[3];
};
// matmul_on_gemmini(
Expand Down
8 changes: 6 additions & 2 deletions tests/golden/test_apps/test_gemmini_conv.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ typedef struct test_case_Context {
} ConfigStore;

} test_case_Context;
struct exo_win_2i32{
int32_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_2i32c{
const int32_t * const data;
const int_fast32_t strides[2];
Expand All @@ -62,8 +66,8 @@ struct exo_win_2i8c{
const int8_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_3i8c{
const int8_t * const data;
struct exo_win_3i8{
int8_t * const data;
const int_fast32_t strides[3];
};
// conv_17(
Expand Down
8 changes: 6 additions & 2 deletions tests/golden/test_apps/test_gemmini_matmul.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ typedef struct test_case_Context {
} ConfigStore;

} test_case_Context;
struct exo_win_2i32{
int32_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_2i32c{
const int32_t * const data;
const int_fast32_t strides[2];
Expand All @@ -66,8 +70,8 @@ struct exo_win_2i8c{
const int8_t * const data;
const int_fast32_t strides[2];
};
struct exo_win_3i8c{
const int8_t * const data;
struct exo_win_3i8{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this test and the others, these structs aren't actually used anywhere. I'm guessing they come from instr signatures, but I'm not totally sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, was there any warning or performance consequence from not having const?

Copy link
Contributor Author

@alexreinking alexreinking Nov 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing we rigorously measured. But there are still a great many reasons for having const.

  1. Having const enables the C compiler to check our generated code. If we misplace a const, the C compiler will warn or error, which tells us that our analysis is wrong somewhere. This helps us catch bugs.
  2. Having accurate const information makes it easier to integrate into existing C codebases. For instance, someone might have a const-qualified buffer in their program. If our generated signatures were missing const, they would have to cast away that const to call an Exo function. That's a recipe for unexpected / undefined behavior.

and so on...

int8_t * const data;
const int_fast32_t strides[3];
};
// cpu_matmul_14(
Expand Down
80 changes: 80 additions & 0 deletions tests/golden/test_codegen/test_const_local_buffer.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@

#pragma once
#ifndef TEST_H
#define TEST_H

#ifdef __cplusplus
extern "C" {
#endif


#include <stdint.h>
#include <stdbool.h>

// Compiler feature macros adapted from Hedley (public domain)
// https://github.com/nemequ/hedley

#if defined(__has_builtin)
# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin)
#else
# define EXO_HAS_BUILTIN(builtin) (0)
#endif

#if EXO_HAS_BUILTIN(__builtin_assume)
# define EXO_ASSUME(expr) __builtin_assume(expr)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
# define EXO_ASSUME(expr) \
((void)((expr) ? 1 : (__builtin_unreachable(), 1)))
#else
# define EXO_ASSUME(expr) ((void)(expr))
#endif


struct exo_win_1f32{
float * const data;
const int_fast32_t strides[1];
};
// caller(

// )
void caller( void *ctxt );



#ifdef __cplusplus
}
#endif
#endif // TEST_H
#include "test.h"



#include <stdio.h>
#include <stdlib.h>


// callee(
// N : size,
// A : [f32][N] @DRAM
// )
static void callee( void *ctxt, int_fast32_t N, struct exo_win_1f32 A );

// callee(
// N : size,
// A : [f32][N] @DRAM
// )
static void callee( void *ctxt, int_fast32_t N, struct exo_win_1f32 A ) {
for (int i = 0; i < N; i++) {
A.data[(i) * (A.strides[0])] = 0.0;
}
}

// caller(

// )
void caller( void *ctxt ) {
float *A = malloc(10 * sizeof(*A));
callee(ctxt,10,(struct exo_win_1f32){ &A[(0) * (1)], { 1 } });
free(A);
}

Loading