Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_binary_elementwise_params(
-------
SerialBinaryElementwise
The parameters needed to construct a binary elementwise operator.
output_pointer : tvm.tir.Var
output_buffer : tvm.tir.Buffer
The output pointer of the binary elementwise operation.
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
Expand All @@ -56,17 +56,17 @@ def get_binary_elementwise_params(
_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
# loads = [input, input, LUT, LUT]
loads = get_loads(inner)
input_pointer = loads[0].buffer.data
input_pointer1 = loads[1].buffer.data
input_buffer = loads[0].buffer
input_buffer1 = loads[1].buffer

if reversed_operands:
input_pointer, input_pointer1 = input_pointer1, input_pointer
output_pointer = inner.buffer.data
input_buffer, input_buffer1 = input_buffer1, input_buffer
output_buffer = inner.buffer
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
serial_ifm, _ = get_ifm_params(input_buffer, producers_consumers, stmt)
serial_ifm2, _ = get_ifm_params(input_buffer1, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params(
output_buffer, producers_consumers, stmt
)
# Get activation info
serial_activation = SerialActivation(
Expand All @@ -87,7 +87,7 @@ def get_binary_elementwise_params(
block_config=serial_block_config,
rescale_config=rescale_config,
),
output_pointer,
replace_pointer,
output_buffer,
replace_buffer,
is_allocator,
)
18 changes: 9 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def get_conv2d_params(stmt, producers_consumers):
-------
Serial2DConvolution
The parameters needed to construct a 2D convolution.
output_pointer : tvm.tir.Var
output_buffer : tvm.tir.Buffer
The output pointer of the convolution operation.
replace_pointer : tvm.tir.Var
replace_buffer : tvm.tir.Buffer
The output pointer of the DMA write operation, which is to replace
the convolution output pointer.
is_allocator : bool
Expand All @@ -60,12 +60,12 @@ def get_conv2d_params(stmt, producers_consumers):
loads = get_loads(rc.body)
# stores = [output]
stores = get_stores(rc.body)
input_pointer = loads[1].buffer.data
output_pointer = stores[0].buffer.data
input_buffer = loads[1].buffer
output_buffer = stores[0].buffer
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params(
output_buffer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
Expand Down Expand Up @@ -157,7 +157,7 @@ def get_conv2d_params(stmt, producers_consumers):
upscale=attrs["upscale"],
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
output_buffer,
replace_buffer,
is_allocator,
)
24 changes: 12 additions & 12 deletions python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def get_depthwise_conv2d_params(
-------
Serial2DDepthwise
The parameters needed to construct a 2D depthwise.
output_pointer : tvm.tir.Var
The output pointer of the convolution operation.
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the convolution output pointer.
output_buffer : tvm.tir.Buffer
The output buffer of the convolution operation.
replace_buffer : tvm.tir.Buffer
The output buffer of the DMA write operation, which is to replace
the convolution output buffer.
is_allocator : bool
Whether this operator allocates its output.

Expand All @@ -64,12 +64,12 @@ def get_depthwise_conv2d_params(
loads = get_loads(rw.body)
# stores = [output]
stores = get_stores(rw.body)
input_pointer = loads[1].buffer.data
output_pointer = stores[0].buffer.data
input_buffer = loads[1].buffer
output_buffer = stores[0].buffer
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params(
output_buffer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_depthwise_conv2d_params(
upscale="NONE",
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
output_buffer,
replace_buffer,
is_allocator,
)
Loading