Skip to content

Commit 5f7c5a5

Browse files
committed
[microNPU] enable striping for network tests.
This commit enables the striping for network tests. Currently it requires, storage_rewrite to be run if striping is enabled to produce correct results. Change-Id: I12b976bb77d339771f8b5a554817d192e7c99723
1 parent 1115fd9 commit 5f7c5a5

File tree

3 files changed

+60
-7
lines changed

3 files changed

+60
-7
lines changed

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,15 @@ def lower_ethosu(sch, args, const_dict, name="main"):
9191
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
9292
mod = ethosu_passes.HoistAllocates()(mod)
9393
mod = ethosu_passes.CopyComputeReordering()(mod)
94+
95+
# When striping is enabled and if storage_rewrite is not run
96+
# the striping results in incorrect code generation. This needs
97+
# further investigation. Until such a time that is fixed, disable_storage_rewrite
98+
# user directive will be overridden if striping is enabled.
9499
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
95-
if not disable_storage_rewrite:
100+
if not disable_storage_rewrite or util.is_striping_enabled():
96101
mod = tvm.tir.transform.StorageRewrite()(mod)
102+
97103
mod = tvm.tir.transform.RemoveNoOp()(mod)
98104
mod = ethosu_passes.AnnotateAllocates()(mod)
99105
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)

tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,13 @@ def tf_graph(x):
158158
assert workspace_size_cascader_enabled_striping_enabled == expected_ws_size_with_striping
159159

160160

161-
# TODO(ekalda): Fix a bug in the block config selection that selects block config that is too large
162-
# for the smaller accelerators
163161
@pytest.mark.parametrize(
164162
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
165163
[
166-
("ethos-u55-256", 180288, 15200),
167-
("ethos-u55-128", 180288, 15200),
168-
("ethos-u55-64", 180288, 14432),
169-
("ethos-u55-32", 180272, 14416),
164+
("ethos-u55-256", 180288, 15312),
165+
("ethos-u55-128", 180288, 15312),
166+
("ethos-u55-64", 180288, 14544),
167+
("ethos-u55-32", 180272, 14544),
170168
],
171169
)
172170
def test_depthwise2d_conv2d_pooling(

tests/python/contrib/test_ethosu/test_networks.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,54 @@ def test_networks_with_usmp_and_cascader_wo_striping(accel_type, model_url, work
139139
assert allocated_pool_info.allocated_size == workspace_size
140140

141141

142+
@pytest.mark.parametrize(
143+
"accel_type, model_url, workspace_size",
144+
[
145+
("ethos-u55-256", MOBILENET_V1_URL, 1005440),
146+
("ethos-u55-256", MOBILENET_V2_URL, 1162368),
147+
],
148+
)
149+
def test_networks_with_usmp_and_cascader_with_striping(accel_type, model_url, workspace_size):
150+
np.random.seed(23)
151+
152+
pool_name = "my_memory_pool"
153+
host_target = tvm.target.Target("c")
154+
ethosu_target = tvm.target.Target("ethos-u")
155+
workspace_pools = WorkspaceMemoryPools(
156+
[
157+
WorkspacePoolInfo(
158+
pool_name,
159+
[host_target, ethosu_target],
160+
PoolInfoProperties(
161+
size_hint_bytes=1200000,
162+
read_bandwidth_bytes_per_cycle=16,
163+
write_bandwidth_bytes_per_cycle=16,
164+
target_burst_bytes={ethosu_target: 1},
165+
),
166+
)
167+
]
168+
)
169+
tflite_model_buf = infra.get_tflite_model(model_url)
170+
input_data, output_data = infra.generate_ref_data_tflite(tflite_model_buf)
171+
mod, params = convert_to_relay(tflite_model_buf)
172+
mod = partition_for_ethosu(mod, params)
173+
test_runner = infra.create_test_runner(
174+
accel_type,
175+
enable_usmp=True,
176+
enable_cascader=True,
177+
enable_striping=True,
178+
workspace_pools=workspace_pools,
179+
)
180+
compiled_models = infra.build_source(
181+
mod, input_data, output_data, test_runner, workspace_pools=workspace_pools
182+
)
183+
infra.verify_source(compiled_models, test_runner)
184+
185+
allocated_pool_info = list(
186+
dict(compiled_models[0].executor_factory.executor_codegen_metadata.pool_inputs).values()
187+
)[0]
188+
assert allocated_pool_info.allocated_size == workspace_size
189+
190+
142191
if __name__ == "__main__":
143192
pytest.main([__file__])

0 commit comments

Comments
 (0)