Skip to content

Commit b94afca

Browse files
Address comments
1 parent b39a7f4 commit b94afca

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

python/tvm/contrib/ethosu/cascader/device_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def __init__(self, shape: List[int], layout="NHWC"):
5555
self.width = int(shape[2])
5656
self.depth = int(shape[3])
5757
elif length == 3:
58-
self.height = int(shape[1])
59-
self.width = int(shape[2])
60-
self.depth = 1
58+
self.height = int(shape[0])
59+
self.width = int(shape[1])
60+
self.depth = int(shape[2])
6161
elif length == 2:
6262
self.height = int(shape[0])
6363
self.width = int(shape[1])

python/tvm/relay/backend/contrib/ethosu/te/identity.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def match_ethosu_identity(output_tensor, device_config):
131131
ofm_dtype = output_tensor.dtype
132132

133133
input_tensors_shape = input_tensors[0].shape
134-
ifm_channels = int(input_tensors_shape[3] if len(input_tensors_shape) > 3 else 1)
135-
ofm_channels = ifm_channels
134+
length = len(input_tensors_shape)
135+
channels = int(input_tensors_shape[length - 1]) if length >= 3 else 1
136136

137137
subkernels = len(device_config.get_kernel_steps(identity.op.name, 1, 1, ifm_dtype))
138138

@@ -143,8 +143,8 @@ def match_ethosu_identity(output_tensor, device_config):
143143
propagators[0],
144144
identity.op.attrs,
145145
output_tensor.shape,
146-
ofm_channels,
147-
ifm_channels,
146+
channels,
147+
channels,
148148
output_layout,
149149
input_layout,
150150
ifm_dtype,

0 commit comments

Comments
 (0)