Skip to content

Commit

Permalink
revert changes to python script
Browse files Browse the repository at this point in the history
  • Loading branch information
TT-billteng committed Jan 24, 2025
1 parent 65af5d6 commit ebf68b2
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions test/python/simple_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,54 +239,69 @@ def _wrapper(*args, **kwargs):
b = TTKernelBuilder(f.__name__, arg_shapes, arg_dtypes)
# print(ast.dump(m, indent=4))
b.visit(m)
# CHECK: "func.func"[[C:.*]]
# CHECK: %[[C:.*]] = "arith.constant"[[C:.*]]
# CHECK: "scf.for"[[C:.*]]
# CHECK: "ttkernel.cb_wait_front"[[C:.*]]
# CHECK: "ttkernel.cb_reserve_back"[[C:.*]]
# CHECK: "ttkernel.tile_regs_acquire"[[C:.*]]
# CHECK: "ttkernel.unpack_ab"[[C:.*]]
# CHECK: "ttkernel.add"[[C:.*]]
# CHECK: "ttkernel.pack"[[C:.*]]
# CHECK: "ttkernel.tile_regs_release"[[C:.*]]
# CHECK: "ttkernel.cb_pop_front"[[C:.*]]
# CHECK: "ttkernel.cb_push_back"[[C:.*]]
print(b.module)
# return f(*args, **kwargs)

return _wrapper


@ttkernel_compile
def eltwise(in0, in1, out):
# CHECK: "func.func"[[C:.*]]
def eltwise(
in0,
in1,
out,
index_maps=[
lambda *dn, m, n: (*dn, m, n),
lambda *dn, m, n: (*dn, m, n),
lambda *dn, m, n: (*dn, m, n),
],
iterator_types=["parallel", "parallel", "parallel"],
dynamic_shapes=False,
):
t6 = Tensix(in0, in1, out)
# CHECK: %[[C:.*]] = "arith.constant"[[C:.*]]
# CHECK: %[[C:.*]] = "arith.constant"[[C:.*]]
# CHECK: %[[C:.*]] = "arith.constant"[[C:.*]]
# CHECK: "scf.for"[[C:.*]]
for dn in range(in0.shape[-3]):
# CHECK: "scf.for"[[C:.*]]
for m in range(in0.shape[-2]):
# CHECK: "scf.for"[[C:.*]]
for n in range(in0.shape[-1]):
# CHECK: "ttkernel.cb_wait_front"[[C:.*]]
in0.wait()
# CHECK: "ttkernel.cb_wait_front"[[C:.*]]
in1.wait()
# CHECK: "ttkernel.cb_reserve_back"[[C:.*]]
out.reserve()
# CHECK: "ttkernel.tile_regs_acquire"[[C:.*]]
t6.tile_regs_acquire()
# CHECK: "ttkernel.unpack_ab"[[C:.*]]
t6.unpack_ab(in0, 0, in1, 0)
# CHECK: "ttkernel.add"[[C:.*]]
t6.add(0)
# CHECK: "ttkernel.pack"[[C:.*]]
t6.pack(0, out, 0)
# CHECK: "ttkernel.tile_regs_release"[[C:.*]]
t6.tile_regs_release()
# CHECK: "ttkernel.cb_pop_front"[[C:.*]]
in0.pop()
# CHECK: "ttkernel.cb_pop_front"[[C:.*]]
in1.pop()
# CHECK: "ttkernel.cb_push_back"[[C:.*]]
out.push()


def test_eltwise():
a = Tensor((8, 128, 128), "float32")
b = Tensor((8, 128, 128), "float32")
out = Tensor((8, 128, 128), "float32")
eltwise(a, b, out)


test_eltwise()
# @ttkernel_generic
# def eltwise_generic(
# in0,
# in1,
# index_maps=[
# lambda *dn, m, n: (*dn, m, n),
# lambda *dn, m, n: (*dn, m, n),
# lambda *dn, m, n: (*dn, m, n),
# ],
# iterator_types=["parallel", "parallel", "parallel"],
# ):
# return in0 + in1


a = Tensor((8, 128, 128), "float32")
b = Tensor((8, 128, 128), "float32")
out = Tensor((8, 128, 128), "float32")
eltwise(a, b, out)

0 comments on commit ebf68b2

Please sign in to comment.