Skip to content

Commit 119afda

Browse files
authored
[microNPU] add E2E tests with cascader w/o striping (#11410)
This commit adds end-to-end tests using the cascader w/o striping. It needed a few adjustments to the order in which the arguments are provided to the entry point function in AoT when both memory pools and devices are present. Change-Id: I37e04afd635add895e317586f628a62cae75f3fa
1 parent c6415d1 commit 119afda

File tree

14 files changed

+354
-138
lines changed

14 files changed

+354
-138
lines changed

python/tvm/contrib/ethosu/cascader/device_config.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, device: str, disable_block_bulling: bool = False):
8484

8585
self._total_banks = 48
8686
self._reserved_banks = 4
87-
self._input_granularity = 8
87+
self._input_granularity = {1: 8, 2: 8, 4: 16}
8888
self._accumulator_granularity = {4: 16, 5: 20}
8989
self._lut_reserved = True
9090
elif self._device == "ethos-u55-128":
@@ -96,7 +96,7 @@ def __init__(self, device: str, disable_block_bulling: bool = False):
9696

9797
self._total_banks = 24
9898
self._reserved_banks = 4
99-
self._input_granularity = 4
99+
self._input_granularity = {1: 4, 2: 4, 4: 8}
100100
self._accumulator_granularity = {4: 8, 5: 12}
101101
self._lut_reserved = True
102102
elif self._device == "ethos-u55-64":
@@ -108,7 +108,7 @@ def __init__(self, device: str, disable_block_bulling: bool = False):
108108

109109
self._total_banks = 16
110110
self._reserved_banks = 2
111-
self._input_granularity = 2
111+
self._input_granularity = {1: 2, 2: 2, 4: 4}
112112
self._accumulator_granularity = {4: 4, 5: 8}
113113
self._lut_reserved = False
114114
elif self._device == "ethos-u55-32":
@@ -120,8 +120,8 @@ def __init__(self, device: str, disable_block_bulling: bool = False):
120120

121121
self._total_banks = 16
122122
self._reserved_banks = 2
123-
self._input_granularity = 2
124-
self._accumulator_granularity = {4: 4, 5: 8}
123+
self._input_granularity = {1: 2, 2: 2, 4: 4}
124+
self._accumulator_granularity = {4: 4, 5: 4}
125125
self._lut_reserved = False
126126

127127
def _get_output_cycles(
@@ -448,18 +448,32 @@ def _get_input_banks(self, input_block_shape, input_bytewidth):
448448
input_block_shape.depth * input_bytewidth, 8
449449
)
450450
input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2
451-
input_banks = _round_up(input_banks, self._input_granularity)
451+
input_banks = _round_up(input_banks, self._input_granularity[input_bytewidth])
452452

453453
return input_banks
454454

455-
def _get_accumulator_banks(self, output_block_shape, acc_bytewidth, depth):
456-
acc_depth = _round_up(min(output_block_shape.depth, depth), 8)
455+
def _get_accumulator_banks(self, output_block_shape, acc_bytewidth):
456+
acc_depth = _round_up(output_block_shape.depth, 8)
457457
acc_bytes = output_block_shape.area() * self._align(acc_depth, 8) * acc_bytewidth
458458
acc_banks = _round_up_div(acc_bytes, self._bank_size_bytes) * 2
459459
acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth])
460460

461461
return acc_banks
462462

463+
@staticmethod
464+
def _create_layout_block(nhwc_block_config, layout):
465+
"""A helper function to convert to brick layout"""
466+
if layout == "NHCWB16":
467+
return [
468+
nhwc_block_config[0],
469+
nhwc_block_config[1],
470+
1 + ((nhwc_block_config[3] - 1) // 16),
471+
nhwc_block_config[2],
472+
16,
473+
]
474+
# else it could only be NHWC
475+
return nhwc_block_config
476+
463477
def get_elementwise_block_config(
464478
self,
465479
ifm_propagator: Propagator,
@@ -537,22 +551,22 @@ def get_elementwise_block_config(
537551
# Split the block in half until it fits into SHRAM
538552
max_height, max_width, max_depth = self._max_block_shape.as_list()[1:]
539553
if output_layout == "NHCWB16":
540-
split_order = (a for a in [1, 3, 2])
541-
output_block = [
542-
output_shape[0],
543-
_round_up(min(output_shape[1], max_height), self._micro_block.height),
544-
min(output_shape[2] * output_shape[4], max_depth),
545-
_round_up(min(output_shape[3], max_width), self._micro_block.width),
546-
16,
547-
]
554+
output_height = output_shape[1]
555+
output_width = output_shape[3]
556+
output_channels = output_shape[2] * 16
548557
else:
549-
split_order = (a for a in [1, 2, 3])
550-
output_block = [
551-
output_shape[0],
552-
_round_up(min(output_shape[1], max_height), self._micro_block.height),
553-
_round_up(min(output_shape[2], max_width), self._micro_block.width),
554-
_round_up(min(output_shape[3], max_depth), self._micro_block.depth),
555-
]
558+
output_height = output_shape[1]
559+
output_width = output_shape[2]
560+
output_channels = output_shape[3]
561+
562+
output_nhwc_block = [
563+
1,
564+
_round_up(min(output_height, max_height), self._micro_block.height),
565+
_round_up(min(output_width, max_width), self._micro_block.width),
566+
_round_up(min(output_channels, max_depth), self._micro_block.depth),
567+
]
568+
output_block = self._create_layout_block(output_nhwc_block, output_layout)
569+
split_order = (a for a in [1, 2, 3])
556570
split_axis = next(split_order)
557571

558572
offset = [0] * len(output_block)
@@ -572,7 +586,7 @@ def get_elementwise_block_config(
572586
)
573587
else:
574588
# Unary elementwise
575-
input2_block = _Shape([0, 0, 0, 0])
589+
input2_block = input_block
576590

577591
input_block.round_up(self._input_micro_block)
578592
input2_block.round_up(self._input_micro_block)
@@ -589,15 +603,19 @@ def get_elementwise_block_config(
589603
)
590604
output_cycles *= reduce(lambda a, b: a * b, output_block, 1)
591605
output_cycles = int(math.ceil(output_cycles))
592-
block_config.append(BlockConfig(output_block, output_block, 0, output_cycles))
606+
block_config.append(
607+
BlockConfig(input_block.as_list(), output_block, 0, output_cycles)
608+
)
593609
break
594610

595-
if output_block[split_axis] == self._micro_block.as_list()[split_axis]:
611+
if output_nhwc_block[split_axis] == self._micro_block.as_list()[split_axis]:
596612
split_axis = next(split_order)
597613

598-
output_block[split_axis] = _round_up(
599-
_round_up_div(output_block[split_axis], 2), self._micro_block.as_list()[split_axis]
614+
output_nhwc_block[split_axis] = _round_up(
615+
_round_up_div(output_nhwc_block[split_axis], 2),
616+
self._micro_block.as_list()[split_axis],
600617
)
618+
output_block = self._create_layout_block(output_nhwc_block, output_layout)
601619

602620
return block_config
603621

@@ -739,7 +757,7 @@ def get_valid_block_configs(
739757
height,
740758
1 + ((depth - 1) // 16),
741759
width,
742-
min(16, _round_up(ofm_channels, self._micro_block.depth)),
760+
16,
743761
)
744762
order = [1, 2, 4, 3, 0]
745763
else:
@@ -771,9 +789,7 @@ def get_valid_block_configs(
771789
# Banks required for input block
772790
input_banks = self._get_input_banks(input_block_shape, input_bytewidth)
773791
# Banks required for accumulation
774-
acc_banks = self._get_accumulator_banks(
775-
output_block_shape, acc_bytewidth, depth
776-
)
792+
acc_banks = self._get_accumulator_banks(output_block_shape, acc_bytewidth)
777793

778794
if (input_banks + acc_banks) <= banks_available:
779795
output_cycles = self._get_output_cycles(

python/tvm/relay/backend/contrib/ethosu/te/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def get_layout_transform_matrices(ofm_channels: int) -> Tuple[List[List[float]],
5353
[1, 0, 0, 0, 0, 0],
5454
[0, 1, 0, 0, 0, 0],
5555
[0, 0, 0, 1, 0, 0],
56-
[0, 0, 0, 0, 0, ofm_channels],
56+
# We need to offset only if number of ofm_channels is not divisible by 16
57+
# Moreover, we can't use just the "ofm_channels" as last element because
58+
# the propagation matrices are used to propagate block configs as well.
59+
[0, 0, 16, 0, 0, -(int(ofm_channels % 16 != 0)) * (16 - ofm_channels % 16)],
5760
[0, 0, 0, 0, 0, 1],
5861
]
5962

src/contrib/ethosu/cascader/parts/ethosu.cc

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,34 +70,41 @@ const std::vector<int64_t> EthosuPartNode::GetBytesRead(const std::vector<int>&
7070
return bytes_per_input;
7171
}
7272

73-
const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stripe_config) {
74-
BlockConfig best_block_config;
75-
float best_cost = std::numeric_limits<float>::infinity();
73+
float EthosuPartNode::CalculateCost(const BlockConfig& block_config,
74+
const StripeConfig& output_stripe_config) {
75+
std::vector<int> output_block = block_config->GetOutputBlockShape();
7676
std::vector<int> output_stripe_shape = output_stripe_config->GetShape();
7777
auto input_stripe_configs = CalculateInputStripeConfigs(output_stripe_config);
7878
std::vector<int> input_stripe_shape = input_stripe_configs[0]->GetShape();
7979

80-
for (const auto& block_config : valid_block_configs_) {
81-
std::vector<int> output_block = block_config->GetOutputBlockShape();
80+
std::vector<int64_t> bytes_per_input = GetBytesRead(output_block, output_stripe_shape);
81+
bytes_per_input[0] *= subkernels_;
8282

83-
std::vector<int64_t> bytes_per_input = GetBytesRead(output_block, output_stripe_shape);
84-
bytes_per_input[0] *= subkernels_;
83+
// Calculate bytes read per output element
84+
float cost =
85+
static_cast<float>(bytes_per_input[0] + bytes_per_input[1]) / mul_reduce(output_stripe_shape);
8586

86-
// Calculate bytes read per output element
87-
float relative_cost = static_cast<float>(bytes_per_input[0] + bytes_per_input[1]) /
88-
mul_reduce(output_stripe_shape);
87+
// Single buffering hardware optimization
88+
if (mul_reduce(input_stripe_shape) <= 2 * mul_reduce(block_config->GetInputBlockShape())) {
89+
cost /= 2;
90+
}
91+
return cost;
92+
}
8993

90-
// Single buffering hardware optimization
91-
if (mul_reduce(input_stripe_shape) <= 2 * mul_reduce(block_config->GetInputBlockShape())) {
92-
relative_cost /= 2;
93-
}
94+
const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stripe_config) {
95+
BlockConfig best_block_config = valid_block_configs_[0];
96+
float best_cost = CalculateCost(best_block_config, output_stripe_config);
97+
std::vector<int> output_stripe_shape = output_stripe_config->GetShape();
98+
auto input_stripe_configs = CalculateInputStripeConfigs(output_stripe_config);
99+
std::vector<int> input_stripe_shape = input_stripe_configs[0]->GetShape();
94100

101+
for (const auto& block_config : valid_block_configs_) {
102+
float relative_cost = CalculateCost(block_config, output_stripe_config);
95103
if (relative_cost < best_cost) {
96104
best_block_config = block_config;
97105
best_cost = relative_cost;
98106
}
99107
}
100-
101108
return best_block_config;
102109
}
103110

src/contrib/ethosu/cascader/parts/ethosu.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ class EthosuPartNode : public PartNode {
7575
const std::vector<int64_t> GetBytesRead(const std::vector<int>& block_shape,
7676
const std::vector<int>& full_shape);
7777

78+
/*!
79+
* \brief Get cost heuristic of using a given block config with the associated stripe config
80+
* \param block_config The block config that is being checked for the cost
81+
* \param output_stripe_config The striping configuration associated with the operator
82+
* \return A cost heuristic representative of the choice
83+
*/
84+
float CalculateCost(const BlockConfig& block_config, const StripeConfig& output_stripe_config);
85+
7886
/*! \brief List of block configs that are valid for this part */
7987
std::vector<BlockConfig> valid_block_configs_;
8088
/*! \brief The output volume that is atomically computed */

src/target/source/interface_c.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@ class InterfaceCNode : public runtime::ModuleNode {
167167
code_stream << " * \\param outputs Output tensors for the module \n";
168168
}
169169

170-
if (!devices_.empty()) {
171-
code_stream << " * \\param devices Device context pointers for the module \n";
172-
}
173170
if (!pools_.empty()) {
174171
code_stream << " * \\param workspace_pools Workspace memory pool pointers for the module \n";
175172
}
173+
if (!devices_.empty()) {
174+
code_stream << " * \\param devices Device context pointers for the module \n";
175+
}
176176

177177
code_stream << " */\n"
178178
<< "int32_t " << run_function << "(\n";
@@ -182,12 +182,12 @@ class InterfaceCNode : public runtime::ModuleNode {
182182
call_args_ss << " struct " << inputs_struct << "* inputs,\n";
183183
call_args_ss << " struct " << outputs_struct << "* outputs,\n";
184184
}
185-
if (!devices_.empty()) {
186-
call_args_ss << " struct " << devices_struct << "* devices,\n";
187-
}
188185
if (!pools_.empty()) {
189186
call_args_ss << " struct " << pools_struct << "* workspace_pools,\n";
190187
}
188+
if (!devices_.empty()) {
189+
call_args_ss << " struct " << devices_struct << "* devices,\n";
190+
}
191191
std::string call_args_str = call_args_ss.str();
192192
call_args_str.pop_back();
193193
call_args_str.pop_back();

tests/cpp/target/source/interface_c_test.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,33 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) {
126126
ASSERT_THAT(header_source, HasSubstr(run_function.str()));
127127
}
128128

129+
TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePoolsAndDevices) {
130+
std::stringstream run_function;
131+
132+
run_function << "/*!\n"
133+
<< " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n"
134+
<< " * \\param inputs Input tensors for the module \n"
135+
<< " * \\param outputs Output tensors for the module \n"
136+
<< " * \\param workspace_pools Workspace memory pool pointers for the module \n"
137+
<< " * \\param devices Device context pointers for the module \n"
138+
<< " */\n"
139+
<< "int32_t tvmgen_ultimate_cat_spotter_run(\n"
140+
<< " struct tvmgen_ultimate_cat_spotter_inputs* inputs,\n"
141+
<< " struct tvmgen_ultimate_cat_spotter_outputs* outputs,\n"
142+
<< " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools,\n"
143+
<< " struct tvmgen_ultimate_cat_spotter_devices* devices\n"
144+
<< ");\n";
145+
146+
PoolInfo pool_info = PoolInfo("my_memory_pool", {});
147+
tir::usmp::AllocatedPoolInfo allocated_pool_info =
148+
tir::usmp::AllocatedPoolInfo(pool_info, 100000);
149+
runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
150+
{allocated_pool_info}, {}, {"device"}, 0);
151+
std::string header_source = test_module->GetSource();
152+
153+
ASSERT_THAT(header_source, HasSubstr(run_function.str()));
154+
}
155+
129156
TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) {
130157
std::stringstream run_function_with_map_functions;
131158

tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,14 @@
166166
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
167167
((1, 4, 4, 96), (1, 4, 6, 4, 16)),
168168
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
169-
((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 10, 1, 6, 4)),
169+
((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 8, 1, 4, 16)),
170170
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
171171
# Depthwise Conv2D
172-
((1, 6, 10, 16), (1, 6, 1, 10, 16)),
173-
((1, 8, 5, 16), (1, 8, 1, 5, 16)),
172+
((1, 6, 10, 16), (1, 4, 1, 12, 16)),
173+
((1, 8, 5, 16), (1, 6, 1, 5, 16)),
174174
# Pooling
175-
((1, 1, 1, 128), (1, 1, 8, 1, 16)),
176-
((1, 9, 6, 16), (1, 9, 1, 6, 16)),
175+
((1, 1, 1, 128), (1, 1, 4, 1, 16)),
176+
((1, 9, 6, 16), (1, 8, 1, 4, 16)),
177177
],
178178
),
179179
(
@@ -184,14 +184,14 @@
184184
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
185185
((1, 4, 4, 96), (1, 4, 6, 4, 16)),
186186
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
187-
((1, 10, 6, 8), (1, 10, 1, 6, 8)),
187+
((1, 10, 6, 8), (1, 8, 1, 4, 16)),
188188
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
189189
# Depthwise Conv2D
190-
((1, 6, 10, 16), (1, 6, 1, 10, 16)),
191-
((1, 8, 5, 16), (1, 8, 1, 5, 16)),
190+
((1, 6, 10, 16), (1, 4, 1, 12, 16)),
191+
((1, 8, 5, 16), (1, 6, 1, 5, 16)),
192192
# Pooling
193-
((1, 1, 1, 128), (1, 1, 8, 1, 16)),
194-
((1, 9, 6, 16), (1, 9, 1, 6, 16)),
193+
((1, 1, 1, 128), (1, 1, 4, 1, 16)),
194+
((1, 9, 6, 16), (1, 8, 1, 4, 16)),
195195
],
196196
),
197197
(
@@ -202,15 +202,15 @@
202202
((1, 5, 8, 16), (1, 5, 1, 8, 16)),
203203
((1, 4, 4, 128), (1, 4, 8, 4, 16)),
204204
((1, 16, 4, 16), (1, 16, 1, 4, 16)),
205-
((1, 8, 12, 8), (1, 8, 1, 12, 8)),
206-
((1, 10, 6, 16), (1, 10, 1, 6, 16)),
205+
((1, 8, 12, 8), (1, 10, 1, 6, 16)),
206+
((1, 10, 6, 16), (1, 10, 1, 6, 16), (1, 6, 1, 6, 16)),
207207
# Depthwise Conv2D
208-
((1, 7, 10, 16), (1, 7, 1, 10, 16), (1, 7, 2, 10, 16)),
209-
((1, 10, 6, 16), (1, 10, 1, 6, 16)),
208+
((1, 7, 10, 16), (1, 7, 1, 10, 16), (1, 6, 1, 10, 16)),
209+
((1, 10, 6, 16), (1, 10, 1, 6, 16), (1, 6, 1, 6, 16)),
210210
# Pooling
211211
# ((1, 1, 2, 16), (1, 1, 1, 2, 16)),
212-
((1, 1, 2, 128), (1, 1, 8, 2, 16)),
213-
((1, 10, 6, 16), (1, 10, 1, 6, 16)),
212+
((1, 1, 2, 128), (1, 1, 4, 2, 16)),
213+
((1, 10, 6, 16), (1, 9, 1, 6, 16)),
214214
],
215215
),
216216
(
@@ -221,14 +221,14 @@
221221
((1, 16, 8, 16), (1, 16, 1, 8, 16)),
222222
((1, 4, 4, 128), (1, 4, 8, 4, 16)),
223223
((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)),
224-
((1, 20, 12, 8), (1, 20, 1, 12, 8)),
224+
((1, 20, 12, 8), (1, 10, 1, 12, 16)),
225225
((1, 12, 10, 16), (1, 12, 1, 10, 16)),
226226
# Depthwise Conv2D
227-
((1, 8, 20, 16), (1, 8, 1, 20, 16), (1, 8, 2, 20, 16)),
228-
((1, 14, 6, 16), (1, 14, 1, 6, 16)),
227+
((1, 8, 20, 16), (1, 6, 1, 20, 16), (1, 6, 2, 20, 16)),
228+
((1, 14, 6, 16), (1, 12, 1, 6, 16)),
229229
# Pooling
230230
# ((1, 2, 2, 16), (1, 2, 1, 2, 16)),
231-
((1, 2, 2, 128), (1, 2, 8, 2, 16)),
231+
((1, 2, 2, 128), (1, 2, 6, 2, 16)),
232232
((1, 10, 12, 16), (1, 10, 1, 12, 16)),
233233
],
234234
),

0 commit comments

Comments
 (0)