Skip to content
This repository has been archived by the owner on May 5, 2024. It is now read-only.

Commit

Permalink
support memref::subview
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Mar 13, 2024
1 parent 86f0a37 commit e9b5383
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 49 deletions.
1 change: 1 addition & 0 deletions openhls/compiler/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def parfor(**kwargs):

def wrapper(body):
for args in itertools.product(*kwargs):
print(f"{args=}")
idx = tuple(i for arg, i in args)
pe_idx = extend_idx(idx)
state.state.update_current_pe_idx(pe_idx=pe_idx)
Expand Down
22 changes: 12 additions & 10 deletions openhls/compiler/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

import networkx as nx
from contextlib import contextmanager
from threading import RLock

from openhls.config import VAL_PREFIX, DTYPE, DEBUG, INCLUDE_AUX_DEPS
from openhls.util import extend_idx
Expand All @@ -17,7 +17,6 @@
class State:
_var_count = 0
_op_call_count = 0
op_graph = nx.MultiDiGraph()
cst_map = {}
cst_count = 0
_pe_idx = (0,)
Expand All @@ -26,12 +25,17 @@ class State:
pe_idx_to_most_recent_op_id = {}
op_id_to_pe_idx = {}
pe_deps = set()
rlock = None

def __init__(self, output_file):
self.op_graph.add_nodes_from(
[INPUT_ARG, MEMREF_ARG, GLOBAL_MEMREF_ARG, CONSTANT]
)
self.output_file = output_file
self.rlock = RLock()

@contextmanager
def with_rlock(self):
self.rlock.acquire()
yield
self.rlock.release()

def incr_var(self):
self._var_count += 1
Expand Down Expand Up @@ -70,12 +74,10 @@ def add_op_res(self, v, op):
self.val_source[v] = op

def maybe_add_op(self, op):
if op not in self.op_graph.nodes:
self.op_graph.add_node(op)
pass

def add_edge(self, op, arg, out_v):
val_source = self.get_arg_src(arg)
self.op_graph.add_edge(val_source, op, input=arg, output=out_v, id=op.op_id)
pass

def update_most_recent_pe_idx(self, pe_idx, op):
self.pe_idx_to_most_recent_op_id[pe_idx] = op.op_id
Expand Down
14 changes: 12 additions & 2 deletions openhls/ir/memref.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from dataclasses import dataclass
from typing import Tuple

Expand Down Expand Up @@ -81,9 +82,18 @@ def reduce_add(self):
def reduce_max(self):
return ReduceMax(list(self.registers.flatten()))

def alias(self, other_memref):
def alias(self, other_memref, offsets=None, sizes=None, strides=None):
assert isinstance(other_memref, MemRef)
self.registers = other_memref.registers
if offsets is not None and sizes is not None and strides is not None:
subview = []
for o, si, st in zip(offsets, sizes, strides):
subview.append(slice(o, o + si, st))
print("subview", subview, file=sys.stderr)
print("before subview", self.registers.shape, file=sys.stderr)
self.registers = other_memref.registers[tuple(subview)]
print("aftier subview", self.registers.shape, file=sys.stderr)
else:
self.registers = other_memref.registers


class GlobalMemRef:
Expand Down
35 changes: 22 additions & 13 deletions openhls_translate/EmitHLSPy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class ModuleEmitter : public OpenHLSEmitterBase {
void emitLoad(memref::LoadOp op);
void emitStore(memref::StoreOp op);
void emitMemCpy(memref::CopyOp op);
void emitMemSubview(memref::SubViewOp op);
void emitGlobal(memref::GlobalOp op);
void emitGetGlobal(memref::GetGlobalOp op);
void emitTensorStore(memref::TensorStoreOp op);
Expand Down Expand Up @@ -420,6 +421,7 @@ class StmtVisitor : public HLSVisitorBase<StmtVisitor, bool> {
bool visitOp(memref::StoreOp op) { return emitter.emitStore(op), true; }
bool visitOp(memref::DeallocOp op) { return true; }
bool visitOp(memref::CopyOp op) { return emitter.emitMemCpy(op), true; }
bool visitOp(memref::SubViewOp op) { return emitter.emitMemSubview(op), true; }
bool visitOp(memref::GlobalOp op) { return emitter.emitGlobal(op), true; }
bool visitOp(memref::GetGlobalOp op) {
return emitter.emitGetGlobal(op), true;
Expand Down Expand Up @@ -1169,33 +1171,40 @@ void ModuleEmitter::emitStore(memref::StoreOp op) {
}

void ModuleEmitter::emitMemCpy(memref::CopyOp op) {
// indent() << "memcpy(";
indent() << "";
// emitValue(op.target());
// os << " = ";
emitValue(op.target());
os << ".alias(";
emitValue(op.getSource());
os << ")";
// os << ", ";
os << "\n";
}

// auto type = op.target().getType().cast<MemRefType>();
// os << type.getNumElements() << " * sizeof(" << getTypeName(op.target())
// << "))";
// os << "\n";
void ModuleEmitter::emitMemSubview(memref::SubViewOp op) {
indent() << "";
emitArrayDecl(op.getResult());
os << "\n";
indent() << "";
emitValue(op.result());
os << ".alias(";
emitValue(op.getSource());
os << ", offsets=" << op.getStaticOffsets();
os << ", sizes=" << op.getStaticSizes();
os << ", strides=" << op.getStaticStrides();
os << ")";
os << "\n";
}

void ModuleEmitter::emitGlobal(memref::GlobalOp op) {
auto initial_val = op.initial_value();
auto elem = initial_val->dyn_cast<DenseFPElementsAttr>();
os << op.sym_name().str() << " = np.array([";
for (const auto &item : elem.getValues<FloatAttr>())
os << item.getValueAsDouble() << ", ";
os << "]).reshape(";

os << op.sym_name().str() << " = np.full((";
for (const auto &item : elem.getType().getShape())
os << item << ", ";
os << "), ";
for (const auto &item : elem.getValues<FloatAttr>()) {
os << item.getValueAsDouble();
break;
}
os << ")\n";
}

Expand Down
3 changes: 2 additions & 1 deletion openhls_translate/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class HLSVisitorBase {
// Memref-related statements.
memref::AllocOp, memref::AllocaOp, memref::LoadOp, memref::StoreOp,
memref::GlobalOp, memref::GetGlobalOp,
memref::DeallocOp, memref::CopyOp, memref::TensorStoreOp,
memref::DeallocOp, memref::CopyOp, memref::SubViewOp, memref::TensorStoreOp,
tensor::ReshapeOp, memref::ReshapeOp, memref::CollapseShapeOp,
memref::ExpandShapeOp, memref::ReinterpretCastOp,
bufferization::ToMemrefOp, bufferization::ToTensorOp,
Expand Down Expand Up @@ -132,6 +132,7 @@ class HLSVisitorBase {
HANDLE(memref::GetGlobalOp);
HANDLE(memref::DeallocOp);
HANDLE(memref::CopyOp);
HANDLE(memref::SubViewOp);
HANDLE(memref::TensorStoreOp);
HANDLE(tensor::ReshapeOp);
HANDLE(memref::ReshapeOp);
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"wheel",
"cmake==3.21",
"cmake>=3.24",
# MLIR build depends.
"ninja",
"numpy==1.23.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ numpy
networkx
astor
jinja2
cocotb==1.6.2
cocotb
matplotlib
xeda
16 changes: 3 additions & 13 deletions scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ if [ ! -f "${OPENHLS_DIR}"/build/llvm/CMakeCache.txt ]; then
-DCMAKE_BUILD_TYPE=DEBUG \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
-DPython3_FIND_VIRTUALENV=ONLY \
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-S "${OPENHLS_DIR}"/externals/llvm-project/llvm \
-B "${OPENHLS_DIR}"/build/llvm
Expand Down Expand Up @@ -137,7 +138,7 @@ if [ ! -f "${OPENHLS_DIR}"/build/flopoco_converter/CMakeCache.txt ]; then
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-S "${OPENHLS_DIR}"/flopoco_convert_ext \
-S "${OPENHLS_DIR}"/extensions/flopoco_convert_ext \
-B "${OPENHLS_DIR}"/build/flopoco_converter
fi

Expand All @@ -154,15 +155,4 @@ if [ ! -f "${OPENHLS_DIR}"/build/ghdl/bin/ghdl ]; then
mkdir -p "${OPENHLS_DIR}"/build/ghdl
tar -xvf ghdl-gha-ubuntu-20.04-llvm.tgz -C "${OPENHLS_DIR}"/build/ghdl
fi
fi


# TODO
#PYBIND11_DIR=${PREFIX}/lib/python3.10/site-packages/pybind11/share/cmake/
#PYBIND11_DIR=$(python -c "import pybind11; print(pybind11.get_cmake_dir())")
#-DPYTHON_LIBRARY="/Users/mlevental/miniforge3/envs/openhls/lib/libpython3.10.dylib" -DPYTHON_INCLUDE_DIR="/Users/mlevental/miniforge3/envs/openhls/include/python3.10" \

# -DPYTHON_INCLUDE_DIR="$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")" \
# -DPYTHON_LIBRARY="$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")" \

#-Dpybind11_DIR=/home/mlevental/miniconda3/envs/openhls/lib/python3.10/site-packages/pybind11/share/cmake/pybind11 -DPython_EXECUTABLE=/home/mlevental/miniconda3/envs/openhls/bin/python
fi
9 changes: 1 addition & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,6 @@ def build_torch_mlir(base_cmake_args):
)


def install_torch_mlir_from_wheel():
torch_mlir_wheel = get_latest_torch_mlir()
subprocess.check_call(
[sys.executable, "-m", "pip", "install", torch_mlir_wheel],
cwd=CWD,
)


def build_circt(base_cmake_args):
circt_dir = os.path.join(EXTERNALS, "circt")
circt_build_dir = os.path.join(ROOT_BUILD_DIR, "circt")
Expand Down Expand Up @@ -168,6 +160,7 @@ def build_openhls_translate(base_cmake_args):
f'-DMLIR_DIR={os.path.join(LLVM_BUILD_DIR, "lib", "cmake", "mlir")}',
f'-DLLVM_DIR={os.path.join(LLVM_BUILD_DIR, "lib", "cmake", "llvm")}',
"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
"-DLLVM_ENABLE_ABI_BREAKING_CHECKS=OFF"
f"-Dpybind11_DIR={pybind11.get_cmake_dir()}",
]
run_cmake(openhls_dir, cmake_args, openhls_build_dir, target="openhls_translate")
Expand Down

0 comments on commit e9b5383

Please sign in to comment.