diff --git a/src/exo/API.py b/src/exo/API.py index 4db9bcc47..961aad648 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -2,6 +2,7 @@ import inspect import re import types +from pathlib import Path from typing import Optional, Union, List from .API_types import ProcedureBase @@ -131,10 +132,16 @@ def do_s(self, s): # Procedure Objects -def compile_procs(proc_list, path, c_file, h_file): +def compile_procs(proc_list, basedir: Path, c_file: str, h_file: str): + c_data, h_data = compile_procs_to_strings(proc_list, h_file) + (basedir / c_file).write_text(c_data) + (basedir / h_file).write_text(h_data) + + +def compile_procs_to_strings(proc_list, h_file_name: str): assert isinstance(proc_list, list) assert all(isinstance(p, Procedure) for p in proc_list) - run_compile([p._loopir_proc for p in proc_list], path, c_file, h_file) + return run_compile([p._loopir_proc for p in proc_list], h_file_name) class Procedure(ProcedureBase): @@ -315,9 +322,8 @@ def c_code_str(self): decls, defns = compile_to_strings("c_code_str", [self._loopir_proc]) return decls + '\n' + defns - def compile_c(self, directory, filename): - run_compile([self._loopir_proc], directory, - (filename + ".c"), (filename + ".h")) + def compile_c(self, directory: Path, filename: str): + compile_procs([self._loopir_proc], directory, f'{filename}.c', f'{filename}.h') def interpret(self, **kwargs): run_interpreter(self._loopir_proc, kwargs) diff --git a/src/exo/LoopIR_compiler.py b/src/exo/LoopIR_compiler.py index 696ae8b4e..512f67d3f 100644 --- a/src/exo/LoopIR_compiler.py +++ b/src/exo/LoopIR_compiler.py @@ -2,6 +2,7 @@ import re from collections import ChainMap from collections import defaultdict +from pathlib import Path from .LoopIR import LoopIR, LoopIR_Do from .LoopIR import T @@ -219,17 +220,14 @@ def window_struct(basetyp, n_dims): # top level compiler function called by tests! -def run_compile(proc_list, path, c_file, h_file, header_guard=None): - file_stem = re.match(r'^([^\.]+)\.[^\.]+$', c_file) - if not file_stem: - raise ValueError("Expected file name to end " - "with extension: e.g. ___.__ ") - lib_name = sanitize_str(file_stem[1]) +def run_compile(proc_list, h_file_name: str): + file_stem = str(Path(h_file_name).stem) + lib_name = sanitize_str(file_stem) fwd_decls, body = compile_to_strings(lib_name, proc_list) - if header_guard is None: - header_guard = re.sub(r'\W', '_', h_file).upper() + body = f'#include "{h_file_name}"\n\n{body}' + header_guard = f'{lib_name}_H'.upper() fwd_decls = f''' #pragma once #ifndef {header_guard} @@ -247,103 +245,68 @@ def run_compile(proc_list, path, c_file, h_file, header_guard=None): #endif // {header_guard} ''' - body = f'#include "{h_file}"\n\n{body}' - - with open(os.path.join(path, h_file), "w") as f_header: - f_header.write(fwd_decls) - - with open(os.path.join(path, c_file), "w") as f_cpp: - f_cpp.write(body) + return body, fwd_decls def compile_to_strings(lib_name, proc_list): - # get transitive closure of call-graph + # Get transitive closure of call-graph orig_procs = [id(p) for p in proc_list] - proc_list = find_all_subprocs(proc_list) - mem_list = find_all_mems(proc_list) - builtin_list = find_all_builtins(proc_list) - config_list = find_all_configs(proc_list) - # check for name conflicts between procs - used_names = set() - for p in proc_list: - if p.name in used_names: - raise Exception(f"Cannot compile multiple " - f"procedures named '{p.name}'") - used_names.add(p.name) - - body = [ - "static int _floor_div(int num, int quot) {", - " int off = (num>=0)? 
0 : quot-1;", - " return (num-off)/quot;", - "}", - "", - "static int8_t _clamp_32to8(int32_t x) {", - " return (x < -128)? -128 : ((x > 127)? 127 : x);", - "}", - "", - ] + # Determine name for library context struct + ctxt_name = f"{lib_name}_Context" - fwd_decls = [] - struct_defns = set() + def from_lines(x): + return '\n'.join(x) - m: Memory - for m in mem_list: - body.append(f'{m.global_()}\n') + proc_list = list(sorted(find_all_subprocs(proc_list), key=lambda x: x.name)) - for b in builtin_list: - glb = b.globl() - if glb: - body.append(glb) - body.append("\n") + # Header contents + ctxt_def = _compile_context_struct(find_all_configs(proc_list), ctxt_name) + struct_defns = set() + fwd_decls = [] - # Build Context Struct - ctxt_name = f"{lib_name}_Context" - ctxt_def = [f"typedef struct {ctxt_name} {{ ", - f""] - for c in config_list: - if c.is_allow_rw(): - sdef_lines = c.c_struct_def() - sdef_lines = [f" {line}" for line in sdef_lines] - ctxt_def += sdef_lines - ctxt_def += [""] - else: - ctxt_def += [f"// config '{c.name()}' not materialized", - ""] - ctxt_def += [f"}} {ctxt_name};"] - fwd_decls += ctxt_def - fwd_decls.append("\n") - # check that we don't have a name conflict on configs - config_names = {c.name() for c in config_list} - if len(config_names) != len(config_list): - raise TypeError("Cannot compile while using two configs " - "with the same name") + # Body contents + memory_code = _compile_memories(find_all_mems(proc_list)) + builtin_code = _compile_builtins(find_all_builtins(proc_list)) + priv_decls = [] + proc_bodies = [] + # Compile proc bodies + seen_procs = set() for p in proc_list: - # don't compile instruction procedures, but add a comment? + if p.name in seen_procs: + raise TypeError(f"multiple procs named {p.name}") + seen_procs.add(p.name) + + # don't compile instruction procedures, but add a comment. if p.instr is not None: argstr = ','.join([str(a.name) for a in p.args]) - body.append("\n/* relying on the following instruction...\n" - f"{p.name}({argstr})\n" - f'{p.instr}\n' - "*/\n") + proc_bodies.extend([ + '', + '/* relying on the following instruction..."', + f"{p.name}({argstr})", + p.instr, + "*/", + ]) else: - p_to_start = p + orig_p = id(p) p = PrecisionAnalysis(p).result() p = WindowAnalysis(p).result() p = MemoryAnalysis(p).result() comp = Compiler(p, ctxt_name) d, b = comp.comp_top() - struct_defns = struct_defns.union(comp.struct_defns()) + struct_defns |= comp.struct_defns() # only dump .h-file forward declarations for requested procedures - if id(p_to_start) in orig_procs: + if orig_p in orig_procs: fwd_decls.append(d) - body.append(b) + else: + priv_decls.append(d) + proc_bodies.append(b) - # add struct definitions before the other forward declarations - fwd_decls = list(struct_defns) + fwd_decls - fwd_decls = '\n'.join(fwd_decls) - fwd_decls = f''' + # Structs are just blobs of code... still sort them for output stability + struct_defns = list(sorted(struct_defns)) + + header_contents = f''' #include #include @@ -365,12 +328,66 @@ def compile_to_strings(lib_name, proc_list): # define EXO_ASSUME(expr) ((void)(expr)) #endif -{fwd_decls} +{from_lines(ctxt_def)} +{from_lines(struct_defns)} +{from_lines(fwd_decls)} ''' - body = '\n'.join(body) + body_contents = f''' +static int _floor_div(int num, int quot) {{ + int off = (num>=0)? 0 : quot-1; + return (num-off)/quot; +}} + +static int8_t _clamp_32to8(int32_t x) {{ + return (x < -128)? -128 : ((x > 127)? 
127 : x); +}} + +{from_lines(memory_code)} +{from_lines(builtin_code)} +{from_lines(priv_decls)} +{from_lines(proc_bodies)} +''' + + return header_contents, body_contents + + +def _compile_builtins(builtins): + builtin_code = [] + for b in sorted(builtins, key=lambda x: x.name()): + if glb := b.globl(): + builtin_code.append(glb) + return builtin_code + + +def _compile_memories(mems): + memory_code = [] + for m in sorted(mems, key=lambda x: x.name()): + memory_code.append(m.global_()) + return memory_code + + +def _compile_context_struct(configs, ctxt_name): + ctxt_def = [f"typedef struct {ctxt_name} {{ ", + f""] + seen = set() + for c in sorted(configs, key=lambda x: x.name()): + name = c.name() - return fwd_decls, body + if name in seen: + raise TypeError(f"multiple configs named {name}") + seen.add(name) + + if c.is_allow_rw(): + sdef_lines = c.c_struct_def() + sdef_lines = [f" {line}" for line in sdef_lines] + ctxt_def += sdef_lines + ctxt_def += [""] + else: + ctxt_def += [f"// config '{name}' not materialized", + ""] + ctxt_def += [f"}} {ctxt_name};"] + return ctxt_def # --------------------------------------------------------------------------- # diff --git a/src/exo/LoopIR_scheduling.py b/src/exo/LoopIR_scheduling.py index 98c7012f0..5243d87ea 100644 --- a/src/exo/LoopIR_scheduling.py +++ b/src/exo/LoopIR_scheduling.py @@ -2848,7 +2848,7 @@ def has_div_mod_config(self, e): def index_start(self, e): assert isinstance(e, LoopIR.expr) # Div and mod need more subtle handling. Don't normalize for now. - # Skip ReadConfigs, they needs careful handling because they're not Sym. + # Skip ReadConfigs, they need careful handling because they're not Sym. if self.has_div_mod_config(e): return e @@ -2856,27 +2856,29 @@ def index_start(self, e): n_map = self.normalize_e(e) # Write back to LoopIR.expr - def get_loopir(key, value): - vconst = LoopIR.Const(value, T.int, e.srcinfo) - if key == self.C: - return vconst - else: - readkey = LoopIR.Read(key, [], e.type, e.srcinfo) - return LoopIR.BinOp('*', vconst, readkey, e.type, e.srcinfo) + def scale_read(coeff, key): + return LoopIR.BinOp( + '*', + LoopIR.Const(coeff, T.int, e.srcinfo), + LoopIR.Read(key, [], e.type, e.srcinfo), + e.type, + e.srcinfo + ) - delete_zero = { key: n_map[key] for key in n_map if n_map[key] != 0 } - new_e = LoopIR.Const(0, T.int, e.srcinfo) - for key, val in delete_zero.items(): - if val > 0: - # add - new_e = LoopIR.BinOp('+', new_e, get_loopir(key, val), e.type, e.srcinfo) + new_e = LoopIR.Const(n_map.get(self.C, 0), T.int, e.srcinfo) + + delete_zero = [(n_map[v], v) + for v in n_map + if v != self.C and n_map[v] != 0] + + for coeff, v in sorted(delete_zero): + if coeff > 0: + new_e = LoopIR.BinOp('+', new_e, scale_read(coeff, v), e.type, e.srcinfo) else: - # sub - new_e = LoopIR.BinOp('-', new_e, get_loopir(key, -val), e.type, e.srcinfo) + new_e = LoopIR.BinOp('-', new_e, scale_read(-coeff, v), e.type, e.srcinfo) return new_e - def map_e(self, e): if e.type.is_indexable(): return self.index_start(e) diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 190d9b37a..d30c82fd9 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -1,4 +1,4 @@ -from .API import Procedure, compile_procs, proc, instr, config +from .API import Procedure, compile_procs, compile_procs_to_strings, proc, instr, config from .LoopIR_scheduling import SchedulingError from .parse_fragment import ParseFragmentError from .configs import Config @@ -12,6 +12,7 @@ __all__ = [ "Procedure", "compile_procs", + "compile_procs_to_strings", "proc", 
"instr", "config", diff --git a/src/exo/pattern_match.py b/src/exo/pattern_match.py index 8bd8a724f..b6317a806 100644 --- a/src/exo/pattern_match.py +++ b/src/exo/pattern_match.py @@ -270,6 +270,12 @@ def match_e(self, pat, e): if isinstance(pat, PAST.E_Hole): return True + # Special case: -3 can be parsed as USub(Const(3))... it should match Const(-3) + if (isinstance(pat, PAST.USub) + and isinstance(pat.arg, PAST.Const) + and isinstance(e, LoopIR.Const)): + pat = pat.arg.update(val=-pat.arg.val) + # first ensure that the pattern and statement # are the same constructor if not isinstance(e, (LoopIR.WindowExpr,) + tuple(_PAST_to_LoopIR[type(pat)])): @@ -290,6 +296,7 @@ def match_e(self, pat, e): elif isinstance(e, LoopIR.Const): return pat.val == e.val elif isinstance(e, LoopIR.BinOp): + # TODO: do we need to handle associativity? (a + b) + c vs a + (b + c)? return ( pat.op == e.op and self.match_e(pat.lhs, e.lhs) and self.match_e(pat.rhs, e.rhs) ) diff --git a/tests/gemmini/conv/test_gemmini_conv_stride1.py b/tests/gemmini/conv/test_gemmini_conv_stride1.py index 6b6a05f78..58112c5be 100644 --- a/tests/gemmini/conv/test_gemmini_conv_stride1.py +++ b/tests/gemmini/conv/test_gemmini_conv_stride1.py @@ -125,12 +125,12 @@ def test_conv_3(): conv = old_reorder(conv, 'och_o kch_o') conv = old_lift_alloc(conv, 'w_s : _', n_lifts=5) conv = old_fission_after(conv, 'w_s = _', n_lifts=5) - conv = old_fission_after(conv, 'if 0 <= orow + krow - 1 and orow + krow - 1 < 56: _', n_lifts=5) - conv = lift_if(conv, - 'if 0 <= orow + krow - 1 and orow + krow - 1 < 56: _', n_lifts=3) - conv = lift_if(conv, - 'if 0 <= orow + krow - 1 and orow + krow - 1 < 56: _ #1', - n_lifts=3) + conv = old_fission_after( + conv, 'if 0 <= -1 + krow + orow and -1 + krow + orow < 56: _', n_lifts=5) + conv = lift_if( + conv, 'if 0 <= -1 + krow + orow and -1 + krow + orow < 56: _', n_lifts=3) + conv = lift_if( + conv, 'if 0 <= -1 + krow + orow and -1 + krow + orow < 56: _ #1', n_lifts=3) conv = old_reorder(conv, 'kch_o ocol_i') conv = specialize(conv, 'for ocol_i in _:_ #1', 'ocol_o == 0 and kcol == 0') @@ -142,13 +142,13 @@ def test_conv_3(): conv = assert_if(conv, 'if _:_ #2', True) conv = assert_if(conv, 'if _:_ #2', True) conv = assert_if(conv, 'if _:_ #2', True) - conv = assert_if(conv, 'if 0 <= ocol_i + 47 + kcol:_', True) + conv = assert_if(conv, 'if 0 <= 47 + kcol + ocol_i:_', True) conv = specialize(conv, 'for ocol_i in _:_ #7', 'kcol == 2') conv = cut_loop(conv, 'ocol_i #7', 7) - conv = repeat(assert_if)(conv, 'if ocol_i + 47 + kcol < 56:_', True) + conv = repeat(assert_if)(conv, 'if 47 + kcol + ocol_i < 56:_', True) conv = unroll_loop(conv, 'ocol_i #8') - print(conv) - conv = assert_if(conv, 'if 0 + 7 + 47 + kcol < 56:_', False) + # print(conv) + conv = assert_if(conv, 'if 47 + kcol + (0 + 7) < 56:_', False) conv = replace(conv, 'for ocol_i in _:_ #0', ld_acc_i32_vector) conv = old_reorder(conv, 'och_i kch_i') diff --git a/tests/gemmini/harness_gemmini.py b/tests/gemmini/harness_gemmini.py index aad71d7a7..6465a7723 100644 --- a/tests/gemmini/harness_gemmini.py +++ b/tests/gemmini/harness_gemmini.py @@ -1,5 +1,6 @@ import os import subprocess +from pathlib import Path from exo import compile_procs @@ -227,14 +228,13 @@ def add_proc(self, p): self.procs.append(p) def compile(self): - path = ENV.TMP_DIR lib_file = f"{self.test_name}_lib.c" h_file = f"{self.test_name}_lib.h" main_file = f"{self.test_name}_main.c" bin_file = self.test_name # write lib.c and lib.h - compile_procs(self.procs, ENV.GEMM_BUILD_DIR, lib_file, 
h_file) + compile_procs(self.procs, Path(ENV.GEMM_BUILD_DIR), lib_file, h_file) # write main.c main_src = gemmini_test_template(h_file, self.glob, self.body) diff --git a/tests/golden/test_apps/test_neon_sgemm.txt b/tests/golden/test_apps/test_neon_sgemm.txt new file mode 100644 index 000000000..d8f7e687a --- /dev/null +++ b/tests/golden/test_apps/test_neon_sgemm.txt @@ -0,0 +1,232 @@ + +#pragma once +#ifndef TEST_CASE_H +#define TEST_CASE_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + +typedef struct test_case_Context { + +} test_case_Context; +struct exo_win_1f32{ + float *data; + int_fast32_t strides[1]; +}; +struct exo_win_2f32{ + float *data; + int_fast32_t strides[2]; +}; +// sgemm_exo( +// M : size, +// N : size, +// K : size, +// A : f32[M,K] @DRAM, +// B : f32[K,N] @DRAM, +// C : f32[M,N] @DRAM +// ) +void sgemm_exo( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, float* A, float* B, float* C ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_CASE_H + +#include "test_case.h" + + +static int _floor_div(int num, int quot) { + int off = (num>=0)? 0 : quot-1; + return (num-off)/quot; +} + +static int8_t _clamp_32to8(int32_t x) { + return (x < -128)? -128 : ((x > 127)? 127 : x); +} + +#include +#include + +#include + +// neon_microkernel( +// K : size, +// A : [f32][4,K] @DRAM, +// B : [f32][K,16] @DRAM, +// C : [f32][4,16] @DRAM +// ) +void neon_microkernel( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + + +/* relying on the following instruction..." +neon_broadcast_4xf32(dst,src) +{dst_data} = vld1q_dup_f32(&{src_data}); +*/ +// neon_microkernel( +// K : size, +// A : [f32][4,K] @DRAM, +// B : [f32][K,16] @DRAM, +// C : [f32][4,16] @DRAM +// ) +void neon_microkernel( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +float32x4_t C_reg[4][4]; +for (int i = 0; i < 4; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = vld1q_f32(&C.data[(i) * (C.strides[0]) + (4 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + float32x4_t A_vec; + for (int i = 0; i < 4; i++) { + A_vec = vld1q_dup_f32(&A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + } + float32x4_t B_vec; + for (int jo = 0; jo < 4; jo++) { + B_vec = vld1q_f32(&B.data[(k) * (B.strides[0]) + (4 * jo) * (B.strides[1])]); + } + for (int i = 0; i < 4; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = vmlaq_f32(C_reg[i][jo], A_vec, B_vec); + } + } +} +for (int i = 0; i < 4; i++) { + for (int jo = 0; jo < 4; jo++) { + vst1q_f32(&C.data[(i) * (C.strides[0]) + (4 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + + +/* relying on the following instruction..." 
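+(Instruction procedures like the ones below are never compiled to C: the
+compiler emits each as a comment, and its instr template, whose {field}
+placeholders are filled in per call site, is inlined wherever the proc is
+invoked.)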
+neon_vfmadd_4xf32_4xf32(dst,lhs,rhs) +{dst_data} = vmlaq_f32({dst_data}, {lhs_data}, {rhs_data}); +*/ + +/* relying on the following instruction..." +neon_vld_4xf32(dst,src) +{dst_data} = vld1q_f32(&{src_data}); +*/ + +/* relying on the following instruction..." +neon_vst_4xf32(dst,src) +vst1q_f32(&{dst_data}, {src_data}); +*/ +// sgemm_exo( +// M : size, +// N : size, +// K : size, +// A : f32[M,K] @DRAM, +// B : f32[K,N] @DRAM, +// C : f32[M,N] @DRAM +// ) +void sgemm_exo( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, float* A, float* B, float* C ) { +EXO_ASSUME(M >= 1); +EXO_ASSUME(N >= 1); +EXO_ASSUME(K >= 1); +EXO_ASSUME(1 == 1); +EXO_ASSUME(1 == 1); +EXO_ASSUME(1 == 1); +float *Atile = malloc(64 * 64 * sizeof(*Atile)); +float *Btile = malloc(64 * 64 * sizeof(*Btile)); +for (int ko = 0; ko < ((K) / (64)); ko++) { + for (int io = 0; io < ((((M) / (4))) / (16)); io++) { + for (int i0 = 0; i0 < 64; i0++) { + for (int i1 = 0; i1 < 64; i1++) { + Atile[(i0) * (64) + (i1) * (1)] = A[(i0 + 64 * io) * (K) + (i1 + 64 * ko) * (1)]; + } + } + for (int jo = 0; jo < ((((N) / (16))) / (4)); jo++) { + for (int i0 = 0; i0 < 64; i0++) { + for (int i1 = 0; i1 < 64; i1++) { + Btile[(i0) * (64) + (i1) * (1)] = B[(i0 + 64 * ko) * (N) + (i1 + 64 * jo) * (1)]; + } + } + for (int im = 0; im < 16; im++) { + for (int jm = 0; jm < 4; jm++) { + neon_microkernel(ctxt,64,(struct exo_win_2f32){ (float*)&Atile[(4 * im) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&Btile[(0) * (64) + (16 * jm) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(4 * im + 64 * io) * (N) + (16 * jm + 64 * jo) * (1)], { N, 1 } }); + } + } + } + } +} +free(Btile); +free(Atile); +for (int ko = 0; ko < ((K) / (64)); ko++) { + for (int io = 0; io < ((((M) / (4))) / (16)); io++) { + for (int jm = 0; jm < ((N) / (16)) % 4; jm++) { + for (int im = 0; im < 16; im++) { + neon_microkernel(ctxt,64,(struct exo_win_2f32){ (float*)&A[(4 * im + 64 * io) * (K) + (64 * ko) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B[(64 * ko) * (N) + (16 * (jm + ((((N) / (16))) / (4)) * 4)) * (1)], { N, 1 } },(struct exo_win_2f32){ (float*)&C[(4 * im + 64 * io) * (N) + (16 * (jm + ((((N) / (16))) / (4)) * 4)) * (1)], { N, 1 } }); + } + } + } + for (int jo = 0; jo < ((N) / (16)); jo++) { + for (int im = 0; im < ((M) / (4)) % 16; im++) { + neon_microkernel(ctxt,64,(struct exo_win_2f32){ (float*)&A[(4 * (im + ((((M) / (4))) / (16)) * 16)) * (K) + (64 * ko) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B[(64 * ko) * (N) + (16 * jo) * (1)], { N, 1 } },(struct exo_win_2f32){ (float*)&C[(4 * (im + ((((M) / (4))) / (16)) * 16)) * (N) + (16 * jo) * (1)], { N, 1 } }); + } + } +} +for (int io = 0; io < ((M) / (4)); io++) { + for (int jo = 0; jo < ((N) / (16)); jo++) { + for (int ii = 0; ii < 4; ii++) { + for (int ji = 0; ji < 16; ji++) { + if (K % 64 > 0) { + for (int ki = 0; ki < K % 64; ki++) { + C[(ii + 4 * io) * (N) + (ji + 16 * jo) * (1)] += A[(ii + 4 * io) * (K) + (ki + ((K) / (64)) * 64) * (1)] * B[(ki + ((K) / (64)) * 64) * (N) + (ji + 16 * jo) * (1)]; + } + } + } + } + } +} +for (int io = 0; io < ((M) / (4)); io++) { + for (int ii = 0; ii < 4; ii++) { + if (N % 16 > 0) { + for (int ji = 0; ji < N % 16; ji++) { + for (int k = 0; k < K; k++) { + C[(ii + 4 * io) * (N) + (ji + ((N) / (16)) * 16) * (1)] += A[(ii + 4 * io) * (K) + (k) * (1)] * B[(k) * (N) + (ji + ((N) / (16)) * 16) * (1)]; + } + } + } + } +} +if (M % 4 > 0) { + for (int ii = 0; ii < M % 4; ii++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < 
K; k++) { + C[(ii + ((M) / (4)) * 4) * (N) + (j) * (1)] += A[(ii + ((M) / (4)) * 4) * (K) + (k) * (1)] * B[(k) * (N) + (j) * (1)]; + } + } + } +} +} + diff --git a/tests/golden/test_apps/test_x86_conv.txt b/tests/golden/test_apps/test_x86_conv.txt new file mode 100644 index 000000000..8cb0e221f --- /dev/null +++ b/tests/golden/test_apps/test_x86_conv.txt @@ -0,0 +1,147 @@ + +#pragma once +#ifndef TEST_CASE_H +#define TEST_CASE_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + +typedef struct test_case_Context { + +} test_case_Context; +struct exo_win_1f32{ + float *data; + int_fast32_t strides[1]; +}; +// conv_specialized( +// inp : f32[5,82,102,128] @DRAM, +// output : f32[5,80,100,128] @DRAM, +// weights : f32[128,3,3,128] @DRAM, +// bias : f32[128] @DRAM +// ) +void conv_specialized( test_case_Context *ctxt, float* inp, float* output, float* weights, float* bias ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_CASE_H + +#include "test_case.h" + + +static int _floor_div(int num, int quot) { + int off = (num>=0)? 0 : quot-1; + return (num-off)/quot; +} + +static int8_t _clamp_32to8(int32_t x) { + return (x < -128)? -128 : ((x > 127)? 127 : x); +} + +#include +#include +#include + +double _relu_(double x) { + if (x > 0.0) return x; + else return 0.0; +} + + +// conv_specialized( +// inp : f32[5,82,102,128] @DRAM, +// output : f32[5,80,100,128] @DRAM, +// weights : f32[128,3,3,128] @DRAM, +// bias : f32[128] @DRAM +// ) +void conv_specialized( test_case_Context *ctxt, float* inp, float* output, float* weights, float* bias ) { +for (int oc_o = 0; oc_o < 2; oc_o++) { + for (int n = 0; n < 5; n++) { + for (int oy = 0; oy < 80; oy++) { + for (int ox_o = 0; ox_o < 20; ox_o++) { + __m512 res[5][4]; + for (int ox_i = 0; ox_i < 5; ox_i++) { + for (int oc_u = 0; oc_u < 4; oc_u++) { + res[ox_i][oc_u] = _mm512_loadu_ps(&bias[(16 * oc_u + 64 * oc_o) * (1)]); + } + } + for (int ky = 0; ky < 3; ky++) { + for (int kx = 0; kx < 3; kx++) { + for (int kc_o = 0; kc_o < 64; kc_o++) { + for (int kc_i = 0; kc_i < 2; kc_i++) { + for (int ox_i = 0; ox_i < 5; ox_i++) { + for (int oc_u = 0; oc_u < 4; oc_u++) { + __m512 wt_vec; + wt_vec = _mm512_loadu_ps(&weights[(kc_i + 2 * kc_o) * (3 * 3 * 128) + (ky) * (3 * 128) + (kx) * (128) + (16 * oc_u + 64 * oc_o) * (1)]); + __m512 in_vec; + (in_vec) = _mm512_set1_ps(inp[(n) * (82 * 102 * 128) + (ky + oy) * (102 * 128) + (kx + ox_i + 5 * ox_o) * (128) + (kc_i + 2 * kc_o) * (1)]); + res[ox_i][oc_u] = _mm512_fmadd_ps((wt_vec), (in_vec), res[ox_i][oc_u]); + } + } + } + } + } + } + for (int ox_i = 0; ox_i < 5; ox_i++) { + for (int oc_u = 0; oc_u < 4; oc_u++) { + __m512 relu_v; + relu_v = _mm512_max_ps(res[ox_i][oc_u], (__m512){0}); + _mm512_storeu_ps(&output[(n) * (80 * 100 * 128) + (oy) * (100 * 128) + (ox_i + 5 * ox_o) * (128) + (16 * oc_u + 64 * oc_o) * (1)], relu_v); + } + } + } + } + } +} +} + + +/* relying on the following instruction..." 
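+(The {A}, {B}, {C_data} fields below are filled in per call site; doubled
+braces, as in mm512_relu_ps's (__m512){{0}}, escape the substitution and emit
+a literal (__m512){0} zero vector in the generated loop above.)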
+mm512_fmadd_ps(A,B,C) +{C_data} = _mm512_fmadd_ps({A}, {B}, {C_data}); +*/ + +/* relying on the following instruction..." +mm512_loadu_ps(dst,src) +{dst_data} = _mm512_loadu_ps(&{src_data}); +*/ + +/* relying on the following instruction..." +mm512_relu_ps(dst,src) +{dst_data} = _mm512_max_ps({src_data}, (__m512){{0}}); +*/ + +/* relying on the following instruction..." +mm512_set1_ps(dst,src) +{dst} = _mm512_set1_ps({src_data}); +*/ + +/* relying on the following instruction..." +mm512_storeu_ps(dst,src) +_mm512_storeu_ps(&{dst_data}, {src_data}); +*/ diff --git a/tests/golden/test_apps/test_x86_sgemm.txt b/tests/golden/test_apps/test_x86_sgemm.txt new file mode 100644 index 000000000..9a95a866c --- /dev/null +++ b/tests/golden/test_apps/test_x86_sgemm.txt @@ -0,0 +1,781 @@ + +#pragma once +#ifndef TEST_CASE_H +#define TEST_CASE_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + +typedef struct test_case_Context { + +} test_case_Context; +struct exo_win_1f32{ + float *data; + int_fast32_t strides[1]; +}; +struct exo_win_2f32{ + float *data; + int_fast32_t strides[2]; +}; +// sgemm_exo( +// M : size, +// N : size, +// K : size, +// A : f32[M,K] @DRAM, +// B : f32[K,N] @DRAM, +// C : f32[M,N] @DRAM +// ) +void sgemm_exo( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, float* A, float* B, float* C ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_CASE_H + +#include "test_case.h" + + +static int _floor_div(int num, int quot) { + int off = (num>=0)? 0 : quot-1; + return (num-off)/quot; +} + +static int8_t _clamp_32to8(int32_t x) { + return (x < -128)? -128 : ((x > 127)? 
127 : x); +} + +#include +#include +#include + +#include +#include + + +// bottom_panel_kernel_scheduled( +// M : size, +// K : size, +// A : [f32][M,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][M,64] @DRAM +// ) +void bottom_panel_kernel_scheduled( test_case_Context *ctxt, int_fast32_t M, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// right_panel_kernel_scheduled( +// N : size, +// K : size, +// A : [f32][6,K] @DRAM, +// B : [f32][K,N] @DRAM, +// C : [f32][6,N] @DRAM +// ) +void right_panel_kernel_scheduled( test_case_Context *ctxt, int_fast32_t N, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_above_kernel( +// M : size, +// N : size, +// K : size, +// A : [f32][M,K] @DRAM, +// B : [f32][K,N] @DRAM, +// C : [f32][M,N] @DRAM +// ) +void sgemm_above_kernel( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_1x4( +// K : size, +// A : [f32][1,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][1,64] @DRAM +// ) +void sgemm_kernel_avx512_1x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_2x4( +// K : size, +// A : [f32][2,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][2,64] @DRAM +// ) +void sgemm_kernel_avx512_2x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_3x4( +// K : size, +// A : [f32][3,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][3,64] @DRAM +// ) +void sgemm_kernel_avx512_3x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_4x4( +// K : size, +// A : [f32][4,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][4,64] @DRAM +// ) +void sgemm_kernel_avx512_4x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_5x4( +// K : size, +// A : [f32][5,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][5,64] @DRAM +// ) +void sgemm_kernel_avx512_5x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// sgemm_kernel_avx512_6x4( +// K : size, +// A : [f32][6,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][6,64] @DRAM +// ) +void sgemm_kernel_avx512_6x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ); + +// bottom_panel_kernel_scheduled( +// M : size, +// K : size, +// A : [f32][M,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][M,64] @DRAM +// ) +void bottom_panel_kernel_scheduled( test_case_Context *ctxt, int_fast32_t M, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(M >= 1); +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +EXO_ASSUME(M < 6); +if (M == 1) { + sgemm_kernel_avx512_1x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(0) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (0) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(0) * (C.strides[0]) + (0) * (C.strides[1])], { C.strides[0], C.strides[1] } }); +} else { + if (M == 
2) { + sgemm_kernel_avx512_2x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(0) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (0) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(0) * (C.strides[0]) + (0) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } else { + if (M == 3) { + sgemm_kernel_avx512_3x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(0) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (0) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(0) * (C.strides[0]) + (0) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } else { + if (M == 4) { + sgemm_kernel_avx512_4x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(0) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (0) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(0) * (C.strides[0]) + (0) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } else { + if (M == 5) { + sgemm_kernel_avx512_5x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(0) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (0) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(0) * (C.strides[0]) + (0) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } else { + for (int k = 0; k < K; k++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < 64; j++) { + C.data[(i) * (C.strides[0]) + (j) * (C.strides[1])] += A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])] * B.data[(k) * (B.strides[0]) + (j) * (B.strides[1])]; + } + } + } + } + } + } + } +} +} + + +/* relying on the following instruction..." +mm512_fmadd_ps(A,B,C) +{C_data} = _mm512_fmadd_ps({A}, {B}, {C_data}); +*/ + +/* relying on the following instruction..." +mm512_loadu_ps(dst,src) +{dst_data} = _mm512_loadu_ps(&{src_data}); +*/ + +/* relying on the following instruction..." +mm512_mask_fmadd_ps(N,A,B,C) +{C_data} = _mm512_mask_fmadd_ps({A}, ((1 << {N}) - 1), {B}, {C_data}); +*/ + +/* relying on the following instruction..." +mm512_mask_set1_ps(N,dst,src) +{dst} = _mm512_set1_ps({src_data}); +*/ + +/* relying on the following instruction..." +mm512_mask_storeu_ps(N,dst,src) +_mm512_mask_storeu_ps(&{dst_data}, ((1 << {N}) - 1), {src_data}); +*/ + +/* relying on the following instruction..." +mm512_maskz_loadu_ps(N,dst,src) +{dst_data} = _mm512_maskz_loadu_ps(((1 << {N}) - 1), &{src_data}); +*/ + +/* relying on the following instruction..." +mm512_set1_ps(dst,src) +{dst} = _mm512_set1_ps({src_data}); +*/ + +/* relying on the following instruction..." 
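+(In the masked templates above, ((1 << {N}) - 1) sets the low N of 16 mask
+bits, so the N % 16 edge columns are loaded, accumulated, and stored one
+partial vector at a time.)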
+mm512_storeu_ps(dst,src) +_mm512_storeu_ps(&{dst_data}, {src_data}); +*/ +// right_panel_kernel_scheduled( +// N : size, +// K : size, +// A : [f32][6,K] @DRAM, +// B : [f32][K,N] @DRAM, +// C : [f32][6,N] @DRAM +// ) +void right_panel_kernel_scheduled( test_case_Context *ctxt, int_fast32_t N, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(N >= 1); +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +EXO_ASSUME(((N) / (16)) < 4); +if (((N) / (16)) == 0) { + __m512 C_reg[6][1]; + __m512 C_reg_1[6]; + for (int i = 0; i < 6; i++) { + C_reg_1[i] = _mm512_maskz_loadu_ps(((1 << (N)) - 1), &C.data[(i) * (C.strides[0]) + (0) * (C.strides[1])]); + } + for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + __m512 A_reg2; + (A_reg2) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg2; + B_reg2 = _mm512_maskz_loadu_ps(((1 << (N)) - 1), &B.data[(k) * (B.strides[0]) + (0) * (B.strides[1])]); + C_reg_1[i] = _mm512_mask_fmadd_ps((A_reg2), ((1 << (N)) - 1), (B_reg2), C_reg_1[i]); + } + } + for (int i = 0; i < 6; i++) { + _mm512_mask_storeu_ps(&C.data[(i) * (C.strides[0]) + (0) * (C.strides[1])], ((1 << (N)) - 1), C_reg_1[i]); + } +} else { + if (((N) / (16)) == 1) { + __m512 C_reg[6][2]; + __m512 C_reg_1[6]; + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 1; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } + C_reg_1[i] = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &C.data[(i) * (C.strides[0]) + (16) * (C.strides[1])]); + } + for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 1; jo++) { + __m512 A_reg; + (A_reg) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg; + B_reg = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_reg), (B_reg), C_reg[i][jo]); + } + __m512 A_reg2; + (A_reg2) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg2; + B_reg2 = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &B.data[(k) * (B.strides[0]) + (16) * (B.strides[1])]); + C_reg_1[i] = _mm512_mask_fmadd_ps((A_reg2), ((1 << (N % 16)) - 1), (B_reg2), C_reg_1[i]); + } + } + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 1; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } + _mm512_mask_storeu_ps(&C.data[(i) * (C.strides[0]) + (16) * (C.strides[1])], ((1 << (N % 16)) - 1), C_reg_1[i]); + } + } else { + if (((N) / (16)) == 2) { + __m512 C_reg[6][3]; + __m512 C_reg_1[6]; + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 2; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } + C_reg_1[i] = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &C.data[(i) * (C.strides[0]) + (32) * (C.strides[1])]); + } + for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 2; jo++) { + __m512 A_reg; + (A_reg) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg; + B_reg = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_reg), (B_reg), C_reg[i][jo]); + } + __m512 A_reg2; + (A_reg2) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg2; + B_reg2 = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), 
&B.data[(k) * (B.strides[0]) + (32) * (B.strides[1])]); + C_reg_1[i] = _mm512_mask_fmadd_ps((A_reg2), ((1 << (N % 16)) - 1), (B_reg2), C_reg_1[i]); + } + } + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 2; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } + _mm512_mask_storeu_ps(&C.data[(i) * (C.strides[0]) + (32) * (C.strides[1])], ((1 << (N % 16)) - 1), C_reg_1[i]); + } + } else { + if (((N) / (16)) == 3) { + __m512 C_reg[6][4]; + __m512 C_reg_1[6]; + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 3; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } + C_reg_1[i] = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &C.data[(i) * (C.strides[0]) + (48) * (C.strides[1])]); + } + for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 3; jo++) { + __m512 A_reg; + (A_reg) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg; + B_reg = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_reg), (B_reg), C_reg[i][jo]); + } + __m512 A_reg2; + (A_reg2) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg2; + B_reg2 = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &B.data[(k) * (B.strides[0]) + (48) * (B.strides[1])]); + C_reg_1[i] = _mm512_mask_fmadd_ps((A_reg2), ((1 << (N % 16)) - 1), (B_reg2), C_reg_1[i]); + } + } + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 3; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } + _mm512_mask_storeu_ps(&C.data[(i) * (C.strides[0]) + (48) * (C.strides[1])], ((1 << (N % 16)) - 1), C_reg_1[i]); + } + } else { + __m512 C_reg[6][(((N) / (16)) + 1)]; + __m512 C_reg_1[6]; + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < ((N) / (16)); jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } + C_reg_1[i] = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &C.data[(i) * (C.strides[0]) + (16 * ((N) / (16))) * (C.strides[1])]); + } + for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < ((N) / (16)); jo++) { + __m512 A_reg; + (A_reg) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg; + B_reg = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_reg), (B_reg), C_reg[i][jo]); + } + __m512 A_reg2; + (A_reg2) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + __m512 B_reg2; + B_reg2 = _mm512_maskz_loadu_ps(((1 << (N % 16)) - 1), &B.data[(k) * (B.strides[0]) + (16 * ((N) / (16))) * (B.strides[1])]); + C_reg_1[i] = _mm512_mask_fmadd_ps((A_reg2), ((1 << (N % 16)) - 1), (B_reg2), C_reg_1[i]); + } + } + for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < ((N) / (16)); jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } + _mm512_mask_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * ((N) / (16))) * (C.strides[1])], ((1 << (N % 16)) - 1), C_reg_1[i]); + } + } + } + } +} +} + +// sgemm_above_kernel( +// M : size, +// N : size, +// K : size, +// A : [f32][M,K] @DRAM, +// B : [f32][K,N] @DRAM, +// C : [f32][M,N] @DRAM +// ) +void sgemm_above_kernel( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { 
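+/* Dispatch structure: full 6x64 tiles go to sgemm_kernel_avx512_6x4; the
+   N % 64 right panel and M % 6 bottom panel go to the scheduled panel
+   kernels above, and the leftover corner falls back to a scalar triple
+   loop. */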
+EXO_ASSUME(M >= 1); +EXO_ASSUME(N >= 1); +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +for (int io = 0; io < ((M) / (6)); io++) { + for (int jo = 0; jo < ((N) / (64)); jo++) { + sgemm_kernel_avx512_6x4(ctxt,K,(struct exo_win_2f32){ (float*)&A.data[(6 * io) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (64 * jo) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(6 * io) * (C.strides[0]) + (64 * jo) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } +} +if (N % 64 > 0) { + for (int io = 0; io < ((M) / (6)); io++) { + right_panel_kernel_scheduled(ctxt,N % 64,K,(struct exo_win_2f32){ (float*)&A.data[(6 * io) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (64 * ((N) / (64))) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(6 * io) * (C.strides[0]) + (64 * ((N) / (64))) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } +} +if (M % 6 > 0) { + for (int jo = 0; jo < ((N) / (64)); jo++) { + bottom_panel_kernel_scheduled(ctxt,M % 6,K,(struct exo_win_2f32){ (float*)&A.data[(6 * ((M) / (6))) * (A.strides[0]) + (0) * (A.strides[1])], { A.strides[0], A.strides[1] } },(struct exo_win_2f32){ (float*)&B.data[(0) * (B.strides[0]) + (64 * jo) * (B.strides[1])], { B.strides[0], B.strides[1] } },(struct exo_win_2f32){ (float*)&C.data[(6 * ((M) / (6))) * (C.strides[0]) + (64 * jo) * (C.strides[1])], { C.strides[0], C.strides[1] } }); + } + if (N % 64 > 0) { + for (int k = 0; k < K; k++) { + for (int ii = 0; ii < M % 6; ii++) { + for (int ji = 0; ji < N % 64; ji++) { + C.data[(ii + ((M) / (6)) * 6) * (C.strides[0]) + (ji + ((N) / (64)) * 64) * (C.strides[1])] += A.data[(ii + ((M) / (6)) * 6) * (A.strides[0]) + (k) * (A.strides[1])] * B.data[(k) * (B.strides[0]) + (ji + ((N) / (64)) * 64) * (B.strides[1])]; + } + } + } + } +} +} + +// sgemm_exo( +// M : size, +// N : size, +// K : size, +// A : f32[M,K] @DRAM, +// B : f32[K,N] @DRAM, +// C : f32[M,N] @DRAM +// ) +void sgemm_exo( test_case_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, float* A, float* B, float* C ) { +EXO_ASSUME(M >= 1); +EXO_ASSUME(N >= 1); +EXO_ASSUME(K >= 1); +EXO_ASSUME(1 == 1); +EXO_ASSUME(1 == 1); +EXO_ASSUME(1 == 1); +static float A1_cache[264 * 512]; +static float B1_cache[512 * 64]; +for (int ko = 0; ko < ((K) / (512)); ko++) { + for (int io = 0; io < ((M) / (264)); io++) { + for (int i0 = 0; i0 < 264; i0++) { + for (int i1 = 0; i1 < 512; i1++) { + A1_cache[(i0) * (512) + (i1) * (1)] = A[(i0 + 264 * io) * (K) + (i1 + 512 * ko) * (1)]; + } + } + for (int jo = 0; jo < ((N) / (64)); jo++) { + for (int i0 = 0; i0 < 512; i0++) { + for (int i1 = 0; i1 < 64; i1++) { + B1_cache[(i0) * (64) + (i1) * (1)] = B[(i0 + 512 * ko) * (N) + (i1 + 64 * jo) * (1)]; + } + } + sgemm_above_kernel(ctxt,264,64,512,(struct exo_win_2f32){ (float*)&A1_cache[(0) * (512) + (0) * (1)], { 512, 1 } },(struct exo_win_2f32){ (float*)&B1_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * io) * (N) + (64 * jo) * (1)], { N, 1 } }); + } + } +} +if (N % 64 > 0) { + for (int ko = 0; ko < ((K) / (512)); ko++) { + static float B2_cache[512 * 64]; + for (int i0 = 0; i0 < 512; i0++) { + for (int i1 = 0; i1 < N - 64 * ((N) / (64)); i1++) { + B2_cache[(i0) * (64) + (i1) * (1)] = B[(i0 + 512 
* ko) * (N) + (64 * ((N) / (64)) + i1) * (1)]; + } + } + for (int io = 0; io < ((M) / (264)); io++) { + sgemm_above_kernel(ctxt,264,N % 64,512,(struct exo_win_2f32){ (float*)&A[(264 * io) * (K) + (512 * ko) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B2_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * io) * (N) + (64 * ((N) / (64))) * (1)], { N, 1 } }); + } + } +} +if (M % 264 > 0) { + for (int ko = 0; ko < ((K) / (512)); ko++) { + for (int jo = 0; jo < ((N) / (64)); jo++) { + static float B3_cache[512 * 64]; + for (int i0 = 0; i0 < 512; i0++) { + for (int i1 = 0; i1 < 64; i1++) { + B3_cache[(i0) * (64) + (i1) * (1)] = B[(i0 + 512 * ko) * (N) + (i1 + 64 * jo) * (1)]; + } + } + sgemm_above_kernel(ctxt,M % 264,64,512,(struct exo_win_2f32){ (float*)&A[(264 * ((M) / (264))) * (K) + (512 * ko) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B3_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * ((M) / (264))) * (N) + (64 * jo) * (1)], { N, 1 } }); + } + } +} +if (M % 264 > 0) { + if (N % 64 > 0) { + for (int ko = 0; ko < ((K) / (512)); ko++) { + static float B4_cache[512 * 64]; + for (int i0 = 0; i0 < 512; i0++) { + for (int i1 = 0; i1 < N - 64 * ((N) / (64)); i1++) { + B4_cache[(i0) * (64) + (i1) * (1)] = B[(i0 + 512 * ko) * (N) + (64 * ((N) / (64)) + i1) * (1)]; + } + } + sgemm_above_kernel(ctxt,M % 264,N % 64,512,(struct exo_win_2f32){ (float*)&A[(264 * ((M) / (264))) * (K) + (512 * ko) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B4_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * ((M) / (264))) * (N) + (64 * ((N) / (64))) * (1)], { N, 1 } }); + } + } +} +if (K % 512 > 0) { + for (int io = 0; io < ((M) / (264)); io++) { + for (int jo = 0; jo < ((N) / (64)); jo++) { + static float B5_cache[512 * 64]; + for (int i0 = 0; i0 < K - 512 * ((K) / (512)); i0++) { + for (int i1 = 0; i1 < 64; i1++) { + B5_cache[(i0) * (64) + (i1) * (1)] = B[(512 * ((K) / (512)) + i0) * (N) + (i1 + 64 * jo) * (1)]; + } + } + sgemm_above_kernel(ctxt,264,64,K % 512,(struct exo_win_2f32){ (float*)&A[(264 * io) * (K) + (512 * ((K) / (512))) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B5_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * io) * (N) + (64 * jo) * (1)], { N, 1 } }); + } + } +} +if (K % 512 > 0) { + if (N % 64 > 0) { + for (int io = 0; io < ((M) / (264)); io++) { + static float B6_cache[512 * 64]; + for (int i0 = 0; i0 < K - 512 * ((K) / (512)); i0++) { + for (int i1 = 0; i1 < N - 64 * ((N) / (64)); i1++) { + B6_cache[(i0) * (64) + (i1) * (1)] = B[(512 * ((K) / (512)) + i0) * (N) + (64 * ((N) / (64)) + i1) * (1)]; + } + } + sgemm_above_kernel(ctxt,264,N % 64,K % 512,(struct exo_win_2f32){ (float*)&A[(264 * io) * (K) + (512 * ((K) / (512))) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B6_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * io) * (N) + (64 * ((N) / (64))) * (1)], { N, 1 } }); + } + } +} +if (K % 512 > 0) { + if (M % 264 > 0) { + for (int jo = 0; jo < ((N) / (64)); jo++) { + static float B7_cache[512 * 64]; + for (int i0 = 0; i0 < K - 512 * ((K) / (512)); i0++) { + for (int i1 = 0; i1 < 64; i1++) { + B7_cache[(i0) * (64) + (i1) * (1)] = B[(512 * ((K) / (512)) + i0) * (N) + (i1 + 64 * jo) * (1)]; + } + } + sgemm_above_kernel(ctxt,M % 264,64,K % 512,(struct exo_win_2f32){ (float*)&A[(264 * ((M) / (264))) * (K) + (512 * ((K) / (512))) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B7_cache[(0) * (64) + 
(0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * ((M) / (264))) * (N) + (64 * jo) * (1)], { N, 1 } }); + } + } +} +if (K % 512 > 0) { + if (M % 264 > 0) { + if (N % 64 > 0) { + static float B8_cache[512 * 64]; + for (int i0 = 0; i0 < K - 512 * ((K) / (512)); i0++) { + for (int i1 = 0; i1 < N - 64 * ((N) / (64)); i1++) { + B8_cache[(i0) * (64) + (i1) * (1)] = B[(512 * ((K) / (512)) + i0) * (N) + (64 * ((N) / (64)) + i1) * (1)]; + } + } + sgemm_above_kernel(ctxt,M % 264,N % 64,K % 512,(struct exo_win_2f32){ (float*)&A[(264 * ((M) / (264))) * (K) + (512 * ((K) / (512))) * (1)], { K, 1 } },(struct exo_win_2f32){ (float*)&B8_cache[(0) * (64) + (0) * (1)], { 64, 1 } },(struct exo_win_2f32){ (float*)&C[(264 * ((M) / (264))) * (N) + (64 * ((N) / (64))) * (1)], { N, 1 } }); + } + } +} +} + +// sgemm_kernel_avx512_1x4( +// K : size, +// A : [f32][1,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][1,64] @DRAM +// ) +void sgemm_kernel_avx512_1x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[1][4]; +for (int i = 0; i < 1; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 1; i++) { + __m512 A_vec; + (A_vec) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 1; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + +// sgemm_kernel_avx512_2x4( +// K : size, +// A : [f32][2,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][2,64] @DRAM +// ) +void sgemm_kernel_avx512_2x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[2][4]; +for (int i = 0; i < 2; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 2; i++) { + __m512 A_vec; + (A_vec) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 2; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + +// sgemm_kernel_avx512_3x4( +// K : size, +// A : [f32][3,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][3,64] @DRAM +// ) +void sgemm_kernel_avx512_3x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[3][4]; +for (int i = 0; i < 3; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = 
_mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 3; i++) { + __m512 A_vec; + (A_vec) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 3; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + +// sgemm_kernel_avx512_4x4( +// K : size, +// A : [f32][4,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][4,64] @DRAM +// ) +void sgemm_kernel_avx512_4x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[4][4]; +for (int i = 0; i < 4; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 4; i++) { + __m512 A_vec; + (A_vec) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 4; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + +// sgemm_kernel_avx512_5x4( +// K : size, +// A : [f32][5,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][5,64] @DRAM +// ) +void sgemm_kernel_avx512_5x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[5][4]; +for (int i = 0; i < 5; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 5; i++) { + __m512 A_vec; + (A_vec) = _mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 5; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + +// sgemm_kernel_avx512_6x4( +// K : size, +// A : [f32][6,K] @DRAM, +// B : [f32][K,64] @DRAM, +// C : [f32][6,64] @DRAM +// ) +void sgemm_kernel_avx512_6x4( test_case_Context *ctxt, int_fast32_t K, struct exo_win_2f32 A, struct exo_win_2f32 B, struct exo_win_2f32 C ) { +EXO_ASSUME(K >= 1); +EXO_ASSUME(A.strides[1] == 1); +EXO_ASSUME(B.strides[1] == 1); +EXO_ASSUME(C.strides[1] == 1); +__m512 C_reg[6][4]; +for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 4; jo++) { + C_reg[i][jo] = _mm512_loadu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])]); + } +} +for (int k = 0; k < K; k++) { + for (int i = 0; i < 6; i++) { + __m512 A_vec; + (A_vec) = 
_mm512_set1_ps(A.data[(i) * (A.strides[0]) + (k) * (A.strides[1])]); + for (int jo = 0; jo < 4; jo++) { + __m512 B_vec; + B_vec = _mm512_loadu_ps(&B.data[(k) * (B.strides[0]) + (16 * jo) * (B.strides[1])]); + C_reg[i][jo] = _mm512_fmadd_ps((A_vec), (B_vec), C_reg[i][jo]); + } + } +} +for (int i = 0; i < 6; i++) { + for (int jo = 0; jo < 4; jo++) { + _mm512_storeu_ps(&C.data[(i) * (C.strides[0]) + (16 * jo) * (C.strides[1])], C_reg[i][jo]); + } +} +} + diff --git a/tests/golden/test_precision/test_good_prec2.txt b/tests/golden/test_precision/test_good_prec2.txt index 0eaa7d2e2..99983e817 100644 --- a/tests/golden/test_precision/test_good_prec2.txt +++ b/tests/golden/test_precision/test_good_prec2.txt @@ -24,7 +24,6 @@ typedef struct c_code_str_Context { } c_code_str_Context; - // hoge( // n : size, // x : f32[n] @DRAM, @@ -33,6 +32,7 @@ typedef struct c_code_str_Context { void hoge( c_code_str_Context *ctxt, int_fast32_t n, float* x, float* y ); + static int _floor_div(int num, int quot) { int off = (num>=0)? 0 : quot-1; return (num-off)/quot; @@ -46,6 +46,14 @@ static int8_t _clamp_32to8(int32_t x) { #include +// dot( +// m : size, +// x : f32[m] @DRAM, +// y : f32[m] @DRAM, +// r : f32 @DRAM +// ) +void dot( c_code_str_Context *ctxt, int_fast32_t m, float* x, float* y, float* r ); + // dot( // m : size, // x : f32[m] @DRAM, @@ -68,3 +76,4 @@ void hoge( c_code_str_Context *ctxt, int_fast32_t n, float* x, float* y ) { float xy; dot(ctxt,n,x,y,&xy); } + diff --git a/tests/golden/test_schedules/test_expand_dim3.txt b/tests/golden/test_schedules/test_expand_dim3.txt index fb7507690..f1522fe1f 100644 --- a/tests/golden/test_schedules/test_expand_dim3.txt +++ b/tests/golden/test_schedules/test_expand_dim3.txt @@ -24,7 +24,6 @@ typedef struct c_code_str_Context { } c_code_str_Context; - // foo( // n : size, // m : size, @@ -33,6 +32,7 @@ typedef struct c_code_str_Context { void foo( c_code_str_Context *ctxt, int_fast32_t n, int_fast32_t m, int8_t* x ); + static int _floor_div(int num, int quot) { int off = (num>=0)? 
0 : quot-1;
     return (num-off)/quot;
@@ -46,6 +46,7 @@ static int8_t _clamp_32to8(int32_t x) {
 
 #include 
 
+
 // foo(
 //     n : size,
 //     m : size,
@@ -70,3 +71,4 @@ for (int i = 0; i < n; i++) {
 }
 }
 }
+
diff --git a/tests/golden/test_schedules/test_simplify.txt b/tests/golden/test_schedules/test_simplify.txt
index 1794bcba2..072a60df7 100644
--- a/tests/golden/test_schedules/test_simplify.txt
+++ b/tests/golden/test_schedules/test_simplify.txt
@@ -1,6 +1,6 @@
 def foo(n: size, m: size):
     x: R[n, 16, 10] @ DRAM
-    for i in seq(0, 5 * n + 8):
+    for i in seq(0, 8 + 5 * n):
         pass
     y: R[10] @ DRAM
     y[1] = 0.0
diff --git a/tests/golden/test_schedules/test_simplify3.txt b/tests/golden/test_schedules/test_simplify3.txt
index 053d5bdfa..0f4029d3f 100644
--- a/tests/golden/test_schedules/test_simplify3.txt
+++ b/tests/golden/test_schedules/test_simplify3.txt
@@ -1,4 +1,4 @@
 def foo(n: size, m: size):
     assert m == 1 and n == 1
     y: R[10] @ DRAM
-    y[10 * m - 8 * n] = 2.0
+    y[0 - 8 * n + 10 * m] = 2.0
diff --git a/tests/golden/test_schedules/test_stage_mem.txt b/tests/golden/test_schedules/test_stage_mem.txt
index 90f7505e1..50fbf1925 100644
--- a/tests/golden/test_schedules/test_stage_mem.txt
+++ b/tests/golden/test_schedules/test_stage_mem.txt
@@ -10,8 +10,9 @@ def sqmat(n: size, A: R[n, n] @ DRAM, B: R[n, n] @ DRAM):
         for ii in seq(0, 4):
             for jj in seq(0, 4):
                 for kk in seq(0, 4):
-                    Atile[ii, jj] += B[4 * i + ii, 4 * k +
-                                       kk] * B[4 * k + kk, 4 * j + jj]
+                    Atile[ii,
+                          jj] += B[ii + 4 * i, kk +
+                                   4 * k] * B[kk + 4 * k, jj + 4 * j]
         for i0 in seq(0, 4):
             for i1 in seq(0, 4):
                 A[i0 + 4 * i, i1 + 4 * j] = Atile[i0, i1]
diff --git a/tests/golden/test_schedules/test_stage_mem_accum.txt b/tests/golden/test_schedules/test_stage_mem_accum.txt
index 61b57b789..e8607d312 100644
--- a/tests/golden/test_schedules/test_stage_mem_accum.txt
+++ b/tests/golden/test_schedules/test_stage_mem_accum.txt
@@ -10,8 +10,9 @@ def sqmat(n: size, A: R[n, n] @ DRAM, B: R[n, n] @ DRAM):
         for ii in seq(0, 4):
             for jj in seq(0, 4):
                 for kk in seq(0, 4):
-                    Atile[ii, jj] += B[4 * i + ii, 4 * k +
-                                       kk] * B[4 * k + kk, 4 * j + jj]
+                    Atile[ii,
+                          jj] += B[ii + 4 * i, kk +
+                                   4 * k] * B[kk + 4 * k, jj + 4 * j]
         for i0 in seq(0, 4):
             for i1 in seq(0, 4):
                 A[i0 + 4 * i, i1 + 4 * j] += Atile[i0, i1]
diff --git a/tests/golden/test_schedules/test_stage_mem_twice.txt b/tests/golden/test_schedules/test_stage_mem_twice.txt
index 2345f6686..898a1bcec 100644
--- a/tests/golden/test_schedules/test_stage_mem_twice.txt
+++ b/tests/golden/test_schedules/test_stage_mem_twice.txt
@@ -6,7 +6,7 @@ def sqmat(n: size, A: R[n, n] @ DRAM, B: R[n, n] @ DRAM):
             B1: R[4, 4]
             for ii in seq(0, 4):
                 for kk in seq(0, 4):
-                    B1[ii, kk] = B[4 * i + ii, 4 * k + kk]
+                    B1[ii, kk] = B[ii + 4 * i, kk + 4 * k]
             B2: R[4, 4] @ DRAM
             for i0 in seq(0, 4):
                 for i1 in seq(0, 4):
@@ -14,5 +14,5 @@ def sqmat(n: size, A: R[n, n] @ DRAM, B: R[n, n] @ DRAM):
             for ii in seq(0, 4):
                 for jj in seq(0, 4):
                     for kk in seq(0, 4):
-                        A[4 * i + ii,
-                          4 * j + jj] += B1[ii, kk] * B2[kk, jj]
+                        A[ii + 4 * i,
+                          jj + 4 * j] += B1[ii, kk] * B2[kk, jj]
diff --git a/tests/golden/test_window/test_normalize.txt b/tests/golden/test_window/test_normalize.txt
index fb1a30d9d..a7a829d58 100644
--- a/tests/golden/test_window/test_normalize.txt
+++ b/tests/golden/test_window/test_normalize.txt
@@ -20,15 +20,13 @@
 # define EXO_ASSUME(expr) ((void)(expr))
 #endif
 
+typedef struct c_code_str_Context { 
+
+} c_code_str_Context;
 struct exo_win_1f32{
     float *data;
     int_fast32_t strides[1];
 };
-typedef struct c_code_str_Context { 
-
-} c_code_str_Context;
-
-
 // proj(
 //     n : size,
 //     m : size,
@@ -38,6 +36,7 @@ typedef struct c_code_str_Context { 
 void proj( c_code_str_Context *ctxt, int_fast32_t n, int_fast32_t m, float* x, float* y );
 
+
 static int _floor_div(int num, int quot) {
     int off = (num>=0)? 0 : quot-1;
     return (num-off)/quot;
@@ -51,6 +50,14 @@ static int8_t _clamp_32to8(int32_t x) {
 
 #include 
 
+// dot(
+//     m : size,
+//     x : [f32][m] @DRAM,
+//     y : [f32][m] @DRAM,
+//     r : f32 @DRAM
+// )
+void dot( c_code_str_Context *ctxt, int_fast32_t m, struct exo_win_1f32 x, struct exo_win_1f32 y, float* r );
+
 // dot(
 //     m : size,
 //     x : [f32][m] @DRAM,
@@ -78,3 +85,4 @@ float y2;
 dot(ctxt,m,(struct exo_win_1f32){ (float*)&x[(1) * (m) + (0) * (1)], { 1 } },(struct exo_win_1f32){ (float*)&y[(0) * (n) + (2) * (1)], { n } },&xy);
 dot(ctxt,m,(struct exo_win_1f32){ (float*)&y[(0) * (n) + (3) * (1)], { n } },(struct exo_win_1f32){ (float*)&y[(0) * (n) + (3) * (1)], { n } },&y2);
 }
+
diff --git a/tests/golden/test_window/test_stride_assert.txt b/tests/golden/test_window/test_stride_assert.txt
index 408c78807..253ed7cd4 100644
--- a/tests/golden/test_window/test_stride_assert.txt
+++ b/tests/golden/test_window/test_stride_assert.txt
@@ -20,15 +20,13 @@
 # define EXO_ASSUME(expr) ((void)(expr))
 #endif
 
+typedef struct c_code_str_Context { 
+
+} c_code_str_Context;
 struct exo_win_2i8{
     int8_t *data;
     int_fast32_t strides[2];
 };
-typedef struct c_code_str_Context { 
-
-} c_code_str_Context;
-
-
 // stride_assert(
 //     n : size,
 //     m : size,
@@ -38,6 +36,7 @@ typedef struct c_code_str_Context { 
 void stride_assert( c_code_str_Context *ctxt, int_fast32_t n, int_fast32_t m, struct exo_win_2i8 src, struct exo_win_2i8 dst );
 
+
 static int _floor_div(int num, int quot) {
     int off = (num>=0)? 0 : quot-1;
     return (num-off)/quot;
@@ -51,6 +50,7 @@ static int8_t _clamp_32to8(int32_t x) {
 
 #include 
 
+
 // stride_assert(
 //     n : size,
 //     m : size,
@@ -69,3 +69,4 @@ for (int i = 0; i < n; i++) {
 }
 }
 }
+
diff --git a/tests/golden/test_window/test_window.txt b/tests/golden/test_window/test_window.txt
index 72374035a..7fd7a404e 100644
--- a/tests/golden/test_window/test_window.txt
+++ b/tests/golden/test_window/test_window.txt
@@ -20,15 +20,13 @@
 # define EXO_ASSUME(expr) ((void)(expr))
 #endif
 
+typedef struct c_code_str_Context { 
+
+} c_code_str_Context;
 struct exo_win_2i8{
     int8_t *data;
     int_fast32_t strides[2];
 };
-typedef struct c_code_str_Context { 
-
-} c_code_str_Context;
-
-
 // window(
 //     n : size,
 //     m : size,
@@ -38,6 +36,7 @@ typedef struct c_code_str_Context { 
 void window( c_code_str_Context *ctxt, int_fast32_t n, int_fast32_t m, struct exo_win_2i8 src, struct exo_win_2i8 dst );
 
+
 static int _floor_div(int num, int quot) {
     int off = (num>=0)? 0 : quot-1;
     return (num-off)/quot;
@@ -51,6 +50,7 @@ static int8_t _clamp_32to8(int32_t x) {
 
 #include 
 
+
 // window(
 //     n : size,
 //     m : size,
@@ -66,3 +66,4 @@ for (int i = 0; i < n; i++) {
 }
 }
 }
+
diff --git a/tests/golden/test_window/test_window_stmt.txt b/tests/golden/test_window/test_window_stmt.txt
index dc23c2f25..55cd2878d 100644
--- a/tests/golden/test_window/test_window_stmt.txt
+++ b/tests/golden/test_window/test_window_stmt.txt
@@ -20,15 +20,13 @@
 # define EXO_ASSUME(expr) ((void)(expr))
 #endif
 
+typedef struct c_code_str_Context { 
+
+} c_code_str_Context;
 struct exo_win_1f32{
     float *data;
     int_fast32_t strides[1];
 };
-typedef struct c_code_str_Context { 
-
-} c_code_str_Context;
-
-
 // window_stmt(
 //     n : size,
 //     m : size,
@@ -37,6 +35,7 @@ typedef struct c_code_str_Context { 
 void window_stmt( c_code_str_Context *ctxt, int_fast32_t n, int_fast32_t m, float* x );
 
+
 static int _floor_div(int num, int quot) {
     int off = (num>=0)? 0 : quot-1;
     return (num-off)/quot;
@@ -50,6 +49,7 @@ static int8_t _clamp_32to8(int32_t x) {
 
 #include 
 
+
 // window_stmt(
 //     n : size,
 //     m : size,
@@ -63,3 +63,4 @@ for (int i = 0; i < n; i++) {
 }
 free(z);
 }
+
diff --git a/tests/test_apps.py b/tests/test_apps.py
new file mode 100644
index 000000000..add219762
--- /dev/null
+++ b/tests/test_apps.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import exo
+import exo.main
+
+REPO_ROOT = Path(__file__).parent.parent.resolve()
+
+
+def _test_app(module_file):
+    mod = exo.main.load_user_code(module_file)
+    procs = exo.main.get_procs_from_module(mod)
+
+    c_file, h_file = exo.compile_procs_to_strings(procs, 'test_case.h')
+
+    return f'{h_file}\n{c_file}'
+
+
+def test_x86_sgemm(golden):
+    module_file = REPO_ROOT / 'apps' / 'x86_demo' / 'sgemm' / 'sgemm.py'
+    assert _test_app(module_file.resolve(strict=True)) == golden
+
+
+def test_x86_conv(golden):
+    module_file = REPO_ROOT / 'apps' / 'x86_demo' / 'conv' / 'conv.py'
+    assert _test_app(module_file.resolve(strict=True)) == golden
+
+
+def test_neon_sgemm(golden):
+    module_file = REPO_ROOT / 'apps' / 'neon_dev' / 'sgemm' / 'sgemm.py'
+    assert _test_app(module_file.resolve(strict=True)) == golden
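Note: for reference, a minimal sketch of driving the new string-based entry point outside the pytest harness, mirroring the helper in tests/test_apps.py above. The user module path ('kernels.py') and output directory ('build') are hypothetical, not part of this change; only load_user_code, get_procs_from_module, and compile_procs_to_strings come from the patch itself.

from pathlib import Path

import exo
import exo.main

# Hypothetical user module containing Exo @proc definitions.
module_file = Path('kernels.py').resolve(strict=True)

# Load the module and collect its procedures, as tests/test_apps.py does.
mod = exo.main.load_user_code(module_file)
procs = exo.main.get_procs_from_module(mod)

# Compile in memory: only the header *name* is passed, since the generated
# C source needs it for its #include line and the include guard is derived
# from its stem; no paths are touched at this point.
c_code, h_code = exo.compile_procs_to_strings(procs, 'kernels.h')

# Write the two strings wherever the build wants them.
out_dir = Path('build')
out_dir.mkdir(exist_ok=True)
(out_dir / 'kernels.c').write_text(c_code)
(out_dir / 'kernels.h').write_text(h_code)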