Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change python kernel interface to accept CBs rather than Tensors #1767

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions test/python/simple_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,13 @@ class TTKernelBuilder(ast.NodeVisitor):
"add": ttkernel.add,
}

def __init__(self, name, arg_shapes, arg_dtypes):
def __init__(self, name, CBIndices):
self.ctx = Context()
self.cursor = Location.unknown(self.ctx)
self.module = Module.create(self.cursor)
self.insert_point = self.module.body
self.name = name
assert len(arg_shapes) == len(arg_dtypes)
self.arg_shapes = arg_shapes
self.arg_dtypes = arg_dtypes
self.arg_tile_shapes = [(32, 32) for i in range(len(arg_shapes))]
self.arg_tile_shapes = [(32, 32) for i in range(len(CBIndices))]
self.symbol_table = {}
self.int_constant_map = {}
ttkernel.register_dialect(self.ctx)
Expand All @@ -55,7 +52,6 @@ def get_constant(self, value):
raise NotImplementedError(f"get_constant {value} not implemented")

def get_tilized_memref_for_arg(self, idx):
arg_shape = list(self.arg_shapes[idx])
arg_dtype = self.arg_dtypes[idx]
arg_tile_shape = list(self.arg_tile_shapes[idx])
assert len(arg_shape) >= 2
Expand All @@ -64,7 +60,7 @@ def get_tilized_memref_for_arg(self, idx):
element_type = tt.ir.TileType.get(
self.ctx, arg_tile_shape[-2], arg_tile_shape[-1], arg_dtype
)
return MemRefType.get(arg_shape, element_type)
return MemRefType.get(1, element_type)

def emit_entry_func(self, node):
assert isinstance(node, ast.FunctionDef)
Expand Down Expand Up @@ -217,10 +213,9 @@ def tile_regs_release(self):
pass


class Tensor:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
class CB:
def __init__(self, index):
self.index = index


def to_data_type(dtype):
Expand All @@ -233,10 +228,9 @@ def to_data_type(dtype):
def ttkernel_compile(f):
@functools.wraps(f)
def _wrapper(*args, **kwargs):
arg_shapes = [tuple(arg.shape) for arg in args]
arg_dtypes = [to_data_type(arg.dtype) for arg in args]
arg_cbindex = [arg.index for arg in args]
m = ast.parse(inspect.getsource(f))
b = TTKernelBuilder(f.__name__, arg_shapes, arg_dtypes)
b = TTKernelBuilder(f.__name__, arg_cbindex)
# print(ast.dump(m, indent=4))
b.visit(m)
# CHECK: "func.func"[[C:.*]]
Expand Down Expand Up @@ -301,7 +295,7 @@ def eltwise(
# return in0 + in1


a = Tensor((8, 128, 128), "float32")
b = Tensor((8, 128, 128), "float32")
out = Tensor((8, 128, 128), "float32")
eltwise(a, b, out)
cb_a = CB(0)
cb_b = CB(1)
cb_out = CB(2)
eltwise(cb_a, cb_b, cb_out)
Loading