Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop IR Interpreter #741

Draft
wants to merge 38 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d98df2c
Adds Neon ISA example
rtzam Aug 23, 2024
75deca4
🤖 apply linter changes (will not trigger CI)
rtzam Aug 23, 2024
1ca41f6
Merge branch 'main' into neon-example
yamaguchi1024 Aug 28, 2024
672a2ac
Merge branch 'main' into neon-example
gilbo Sep 3, 2024
ec6ad38
Merge branch 'neon-example' of github.com:exo-lang/exo into neon-example
Oct 4, 2024
c6aba82
dump old interpreter code and write matmul example that invokes the i…
Oct 15, 2024
312daf9
stencil missing LoopIR cases in the interpreter
meganfrisella Oct 16, 2024
6f3bbc2
interpreter additions: precondition checks, window expressions (inclu…
meganfrisella Oct 17, 2024
72f9352
test suite comparing interpreter to compiler. hits most statements an…
meganfrisella Oct 23, 2024
2cd89f1
strides + free + some env map changes
andrewdalex Oct 25, 2024
dc6f132
write more stride tests, interpreting LoopIR.StrideExpr should return…
meganfrisella Oct 29, 2024
bc0853f
add tests highlighting discrepancy between compiler and interpreter o…
meganfrisella Oct 29, 2024
bf4e4cb
impl LoopIR.WindowStmt in the interpreter. add more window tests. add…
meganfrisella Oct 30, 2024
9648851
typecheck input buffer
andrewdalex Oct 30, 2024
6e1b34a
fix WindowExpr interpreter bug (failed to handle LoopIR.Point case), …
meganfrisella Oct 30, 2024
6ba2a60
Merge branch 'interpret' of github.com:exo-lang/exo into interpret
meganfrisella Oct 30, 2024
5663337
Adds Neon ISA example
rtzam Aug 23, 2024
4edce2a
🤖 apply linter changes (will not trigger CI)
rtzam Aug 23, 2024
04d9e07
dump old interpreter code and write matmul example that invokes the i…
Oct 15, 2024
1719f3e
stencil missing LoopIR cases in the interpreter
meganfrisella Oct 16, 2024
b2afe25
interpreter additions: precondition checks, window expressions (inclu…
meganfrisella Oct 17, 2024
b8437f6
test suite comparing interpreter to compiler. hits most statements an…
meganfrisella Oct 23, 2024
a770dab
strides + free + some env map changes
andrewdalex Oct 25, 2024
143b790
write more stride tests, interpreting LoopIR.StrideExpr should return…
meganfrisella Oct 29, 2024
f1294aa
add tests highlighting discrepancy between compiler and interpreter o…
meganfrisella Oct 29, 2024
eee0b09
impl LoopIR.WindowStmt in the interpreter. add more window tests. add…
meganfrisella Oct 30, 2024
c59781f
fix WindowExpr interpreter bug (failed to handle LoopIR.Point case), …
meganfrisella Oct 30, 2024
27ae79e
typecheck input buffer
andrewdalex Oct 30, 2024
dd8784c
refactor interpreter; fix C division bug; fix some of the failing tes…
andrewdalex Nov 3, 2024
cb07546
🤖 apply linter changes (will not trigger CI)
andrewdalex Nov 6, 2024
bf5c6f5
remove extra files
andrewdalex Nov 7, 2024
f01b541
Revert "remove extra files"
andrewdalex Nov 7, 2024
83063e9
fix deleted stuff
andrewdalex Nov 7, 2024
d26ef82
more unnecessary changes
andrewdalex Nov 7, 2024
d64adac
clean up
meganfrisella Nov 12, 2024
2ab99de
merge
meganfrisella Nov 12, 2024
9bd7383
🤖 apply linter changes (will not trigger CI)
meganfrisella Nov 12, 2024
430e6de
debug test_interp.py - all tests expected to pass do pass
meganfrisella Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ configure

dependencies/chipyard

examples/*.c
examples/*.d
examples/*.h

.vscode
117 changes: 117 additions & 0 deletions examples/arm_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

import os
import sys

from exo import proc
from exo.platforms.neon import *
from exo.stdlib.scheduling import *

# Hide output when running through exocc.
if __name__ != "__main__" and hasattr(os, "devnull"):
sys.stdout = open(os.devnull, "w")

# Algorithm definition
@proc
def rank_k_reduce_6x16(
K: size, A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM
):
for i in seq(0, 6):
for j in seq(0, 16):
for k in seq(0, K):
C[i, j] += A[i, k] * B[k, j]


print("\n============= original ==============")
print(rank_k_reduce_6x16)

print("\n============= reorder loops ==============")
neon = rename(rank_k_reduce_6x16, "rank_k_reduce_6x16_scheduled")
neon = reorder_loops(neon, "j k")
neon = reorder_loops(neon, "i k")
print(neon)


print("\n============= divide loop ==============")
# neon only supports vectors of width 4 for f32
# x86 supports either 4 or 8 wide
vec_reg_width = 4
neon = divide_loop(neon, "for j in _: _", vec_reg_width, ["jo", "ji"], perfect=True)
print(neon)

print("\n============= stage mem ==============")
# we want the computation to be "output stationary", which means,
# we want to preallocate all the output registers at the start.
# The staging of C will cause us to consume 12 out of the 16 vector registers
neon = stage_mem(neon, "for k in _:_", "C[0:6, 0:16]", "C_reg")
print(neon)
neon = simplify(neon)

print("\n============= reshape C_reg ==============")
# Reshape C_reg so we can map it into vector registers
neon = divide_dim(neon, "C_reg:_", 1, vec_reg_width)
print(neon)

print("\n============= divide loop ==============")
neon = repeat(divide_loop)(
neon, "for i1 in _: _", vec_reg_width, ["i2", "i3"], perfect=True
)
neon = simplify(neon)
print(neon)

print("\n============= map C_reg ops ==============")
# Map C_reg operations to vector instructions
neon = set_memory(neon, "C_reg:_", Neon)
# this loads 8 items into the register but neon only loads 4
neon = replace_all(neon, neon_vld_4xf32)
neon = replace_all(neon, neon_vst_4xf32)
neon = simplify(neon)
print(neon)


# Now, the rest of the compute needs to work with the constraint that the
# we only have 4 more registers to work with here.

print("\n============= stage B_reg ==============")
# B is easy, it is just two vector loads
neon = stage_mem(neon, "for i in _:_", "B[k, 0:16]", "B_reg")
neon = simplify(neon)
print(neon)

print("\n============= block 1st B_reg load ==============")
neon = divide_loop(neon, "for i0 in _: _ #1", vec_reg_width, ["io", "ii"], perfect=True)
print(neon)

print("\n============= reshape B_reg ==============")
neon = divide_dim(neon, "B_reg:_", 0, vec_reg_width)
print(neon)

print("\n============= map B_reg ops ==============")
neon = set_memory(neon, "B_reg:_", Neon)
neon = simplify(neon)
neon = replace_all(neon, neon_vld_4xf32)
neon = simplify(neon)
print(neon)

# Now we've used up two more vector registers.
# The final part is staging A

print("\n============= stage A_reg ==============")
neon = bind_expr(neon, "A[i, k]", "A_reg")
neon = expand_dim(neon, "A_reg", vec_reg_width, "ji")
neon = lift_alloc(neon, "A_reg", n_lifts=2)
neon = fission(neon, neon.find("A_reg[ji] = _").after(), n_lifts=2)
neon = remove_loop(neon, "for jo in _: _")
neon = set_memory(neon, "A_reg:_", Neon)
neon = replace_all(neon, neon_broadcast_4xf32)
neon = simplify(neon)
print(neon)


# DO THE COMPUTE!!!
print("\n============= map mult add op ==============")
neon = replace_all(neon, neon_vfmadd_4xf32_4xf32)
neon = simplify(neon)
print(neon)

print("\n============= dnone! ==============")
20 changes: 16 additions & 4 deletions examples/avx2_matmul/Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
CFLAGS ?= -march=native

avx2_matmul: avx2_matmul.o main.o
.PHONY: x86
x86: avx2_matmul

avx2_matmul.c: x86_matmul.py
# x86 build
avx2_matmul: avx2_matmul.o main.o
avx2_matmul.h avx2_matmul.c: x86_matmul.py
exocc -o . --stem $(*F) $^

main.c: avx2_matmul.c
.PHONY: neon
neon: neon_matmul

# ARM
neon_matmul: neon_matmul.o main.o
neon_matmul.h neon_matmul.c: arm_matmul.py
exocc -o . --stem $(*F) $^

.PHONY: clean
clean:
$(RM) avx2_matmul avx2_matmul.* *.o exo_demo
$(RM) *.o exo_demo
$(RM) -r __pycache__/
$(RM) avx2_matmul avx2_matmul.*
$(RM) neon_matmul neon_matmul.*

9 changes: 8 additions & 1 deletion examples/avx2_matmul/main.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#include <stdint.h>
#include <stdio.h>
#include <time.h>

#include "avx2_matmul.h"
// generated from exo
void rank_k_reduce_6x16(
void *ctxt, int_fast32_t K, const float *A, const float *B, float *C);
void rank_k_reduce_6x16_scheduled(
void *ctxt, int_fast32_t K, const float *A, const float *B, float *C);

#define K 2048
static float A[6 * K];
Expand Down Expand Up @@ -31,6 +36,8 @@ int main() {
clock_t start, end;
int msec;

initialize();

// Calling original matmul
start = clock();
for (int i = 0; i < 1000; i++)
Expand Down
62 changes: 62 additions & 0 deletions examples/matmul_interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import os
import sys
import numpy as np

from exo import proc
from exo.platforms.neon import *
from exo.stdlib.scheduling import *

# Hide output when running through exocc.
if __name__ != "__main__" and hasattr(os, "devnull"):
sys.stdout = open(os.devnull, "w")


@proc
def foo(s: f32, arg: f32[1, 1] @ DRAM):
arg[0, 0] = s


# Algorithm definition
@proc
def rank_k_reduce_6x16(
M: size,
K: size,
N: size,
A: f32[M, K] @ DRAM,
B: f32[K, N] @ DRAM,
C: f32[M, N] @ DRAM,
test: f32 @ DRAM,
):
s: f32
buf: f32[1, 1]
s = 4

for i in seq(0, M):
for j in seq(0, N):
for k in seq(0, K):
C[i, j] += A[i, k] * B[k, j]
s: f32
s = 2
s = s + 1
test = s


@proc
def check_stride(A: [f32][6] @ DRAM, res: f32 @ DRAM):
assert stride(A, 0) == 2
for i in seq(0, 6):
res += A[i]


# M = 2; K = 2; N = 2
# A = np.zeros(M*K, dtype=float).reshape((M,K))
# B = np.arange(K*N, dtype=float).reshape((K,N))
# C = np.zeros(M*N, dtype=float).reshape((M,N))
res = np.zeros(1)

A = np.array([1.0] * 12)
# rank_k_reduce_6x16.interpret(M=M, K=K, N=N, A=A, B=B, C=C, test=res)
check_stride.interpret(A=A[::2], res=res)
print(res)
4 changes: 4 additions & 0 deletions rebuild.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#! /bin/bash

python -m build .
pip install --force-reinstall dist/*.whl
5 changes: 5 additions & 0 deletions rebuild_and_test_interp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#! /bin/bash

python -m build .
pip install --force-reinstall dist/*.whl
python3 examples/matmul_interp.py
4 changes: 4 additions & 0 deletions src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .frontend.pattern_match import match_pattern
from .core.prelude import *
from .rewrite.new_eff import Check_Aliasing
from .backend.LoopIR_interpreter import run_interpreter

# Moved to new file
from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc
Expand Down Expand Up @@ -302,6 +303,9 @@ def c_code_str(self):
def compile_c(self, directory: Path, filename: str):
compile_procs([self], directory, f"{filename}.c", f"{filename}.h")

def interpret(self, **kwargs):
run_interpreter(self._loopir_proc, kwargs)

# ------------------------------- #
# scheduling operations
# ------------------------------- #
Expand Down
Loading
Loading