-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[BYOC][ACL] Fix list is not supported as an input node #10801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5845760
[BYOC][ACL] Fix list is not supported as an input node
937e91a
fix clang lint error
b47f26a
fix compile warnning
4089f17
fix python module import error
1d3aebc
rename concatenate test file
d89b7d4
fix always MakeACLTensor with same eid 0
8c87d45
do not offload concat default
6c03b19
fix concattnate test failure
9d36318
fix test failure
ca28694
fix lint error
9bb6421
fix lint
56eb714
remove global var offload_concat
b219357
support concatenate with pattern table mechanism
4603c12
disable pylint dangerous-default-value warning
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
tests/python/contrib/test_arm_compute_lib/test_concatenate.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Arm Compute Library integration space_to_batch_nd tests.""" | ||
DzAvril marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| import numpy as np | ||
|
|
||
| import tvm | ||
| from tvm import relay | ||
| from tvm import testing | ||
|
|
||
| from test_arm_compute_lib.infrastructure import ( | ||
| skip_runtime_test, | ||
| skip_codegen_test, | ||
| build_and_run, | ||
| verify, | ||
| verify_codegen, | ||
| ) | ||
| from test_arm_compute_lib.infrastructure import Device | ||
|
|
||
|
|
||
| def _get_model(input_shape_a, input_shape_b, input_shape_c, axis, dtype, var_names): | ||
| """Return a model and any parameters it may have.""" | ||
| a = relay.var(next(var_names), shape=input_shape_a, dtype=dtype) | ||
| b = relay.var(next(var_names), shape=input_shape_b, dtype=dtype) | ||
| c = relay.var(next(var_names), shape=input_shape_c, dtype=dtype) | ||
| out = relay.concatenate([a, b, c], axis) | ||
| return out | ||
|
|
||
|
|
||
| def _get_expected_codegen(input_shape_a, input_shape_b, input_shape_c, axis, dtype): | ||
| node = { | ||
| "op": "kernel", | ||
| "name": "concatenate", | ||
| "inputs": [ | ||
| [0, 0, 0], | ||
| [0, 1, 0], | ||
| [0, 2, 0], | ||
| ], | ||
| "attrs": { | ||
| "num_outputs": "1", | ||
| "num_inputs": "3", | ||
| "dtype": [[dtype]], | ||
| "axis": [[str(axis)]], | ||
| "shape": [[[3, 234, 234, 256]]], | ||
| }, | ||
| } | ||
|
|
||
| input = { | ||
| "op": "input", | ||
| "name": "", | ||
| "attrs": { | ||
| "shape": [[input_shape_a, input_shape_b, input_shape_c]], | ||
| "dtype": [[dtype, dtype, dtype]], | ||
| }, | ||
| } | ||
|
|
||
| return [input, node] | ||
|
|
||
|
|
||
| def test_concatenate(): | ||
| Device.load("test_config.json") | ||
|
|
||
| if skip_runtime_test(): | ||
| return | ||
|
|
||
| device = Device() | ||
| np.random.seed(0) | ||
|
|
||
| for input_shape_a, input_shape_b, input_shape_c, axis in [ | ||
| ([1, 234, 234, 256], [1, 234, 234, 256], [1, 234, 234, 256], 0), | ||
DzAvril marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ]: | ||
| dtype = "int32" | ||
DzAvril marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| outputs = [] | ||
| inputs = { | ||
| "a": tvm.nd.array(np.random.randn(*input_shape_a).astype(dtype)), | ||
| "b": tvm.nd.array(np.random.randn(*input_shape_b).astype(dtype)), | ||
| "c": tvm.nd.array(np.random.randn(*input_shape_c).astype(dtype)), | ||
| } | ||
| func = _get_model( | ||
| inputs["a"].shape, inputs["b"].shape, inputs["c"].shape, axis, dtype, iter(inputs) | ||
| ) | ||
| for acl in [False, True]: | ||
| outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0]) | ||
|
|
||
| config = { | ||
| "input_shape_a": input_shape_a, | ||
| "input_shape_b": input_shape_b, | ||
| "input_shape_c": input_shape_c, | ||
| "axis": 0, | ||
| "dtype": dtype, | ||
| } | ||
| verify(outputs, atol=1e-7, rtol=1e-7, config=config) | ||
|
|
||
|
|
||
| def test_codegen_concatenate(): | ||
| if skip_codegen_test(): | ||
| return | ||
| shape_a = [1, 234, 234, 256] | ||
| shape_b = [1, 234, 234, 256] | ||
| shape_c = [1, 234, 234, 256] | ||
| axis = 0 | ||
| inputs = {"a", "b", "c"} | ||
| for dtype in ["float32"]: | ||
| args = (shape_a, shape_b, shape_c, axis, dtype) | ||
| func = _get_model(*args, iter(inputs)) | ||
| exp_codegen = _get_expected_codegen(*args) | ||
| verify_codegen(func, exp_codegen, 1) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_concatenate() | ||
| test_codegen_concatenate() | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.