
Commit d563e4c

Merge pull request pytorch#7 from cavusmustafa/openvino_backend_unit_tests
Initial unit tests for OpenVINO backend
2 parents 379937e + ecbe5e2 commit d563e4c

File tree

14 files changed: +756 -2 lines changed


backends/openvino/preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def preprocess(
         for spec in module_compile_spec:
             compile_options[spec.key] = spec.value.decode()
 
-        compiled = openvino_compile(edge_program.module(), *args, options=compile_options)
+        compiled = openvino_compile(edge_program.module(), *args, options=compile_options, executorch=True)
         model_bytes = compiled.export_model()
 
         return PreprocessResult(processed_bytes=model_bytes)
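
For context, here is a minimal sketch of how a compile spec reaches this preprocess step through the standard to_backend flow; it mirrors the lowering sequence used by the unit tests below. The nn.Linear module and the "device" -> "CPU" spec are illustrative assumptions, not part of this commit.

# Sketch only: each CompileSpec entry becomes one compile_options key via spec.value.decode().
import torch
from torch.export import export
from executorch.exir import to_edge
from executorch.exir.backend.backend_details import CompileSpec
from executorch.backends.openvino.partitioner import OpenvinoPartitioner

module = torch.nn.Linear(4, 4).eval()   # placeholder module (assumption)
sample_inputs = (torch.randn(1, 4),)

edge_program = to_edge(export(module, sample_inputs))

compile_spec = [CompileSpec("device", "CPU".encode())]  # decoded into compile_options["device"] = "CPU"
lowered = edge_program.to_backend(OpenvinoPartitioner(compile_spec))  # lowering calls the backend preprocess
exec_prog = lowered.to_executorch()
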
backends/openvino/tests/ops/base_openvino_op_test.py

Lines changed: 154 additions & 0 deletions

import os
import subprocess
import tempfile
import unittest

import numpy as np
import torch
import executorch
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
from executorch.exir.backend.backend_details import CompileSpec
from torch.export import export, ExportedProgram
from executorch.exir import EdgeProgramManager, to_edge
from executorch.backends.openvino.preprocess import OpenvinoBackend


class BaseOpenvinoOpTest(unittest.TestCase):
    device = "CPU"
    build_folder = ""

    atol = 1e-1
    rtol = 1e-1

    def execute_layer_test(
        self,
        module: torch.nn.Module,
        sample_inputs: tuple[torch.Tensor],
        expected_partitions: int = 1,
        assert_output_equal: bool = True,
    ):

        module = module.eval()
        # Export to aten dialect using torch.export
        aten_dialect: ExportedProgram = export(module, sample_inputs)

        # Convert to edge dialect
        edge_program: EdgeProgramManager = to_edge(aten_dialect)
        to_be_lowered_module = edge_program.exported_program()

        # Lower the module to the backend with a custom partitioner
        compile_spec = [CompileSpec("device", self.device.encode())]
        lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec))

        # Apply backend-specific passes
        exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig())

        # Check if the number of partitions created matches the expected number of partitions
        self.assertEqual(
            len(exec_prog.executorch_program.execution_plan[0].delegates),
            expected_partitions,
        )
        # Check if the individual partitions are assigned to Openvino backend
        for i in range(expected_partitions):
            self.assertEqual(
                exec_prog.executorch_program.execution_plan[0].delegates[i].id,
                OpenvinoBackend.__name__,
            )

        # Execute the model and compare the outputs with the reference outputs
        if assert_output_equal:
            with tempfile.TemporaryDirectory() as tmp_dir:
                input_list = ""
                for idx, _ in enumerate(sample_inputs):
                    input_name = f"input_0_{idx}.raw"
                    input_list += input_name + " "
                input_list = input_list.strip() + "\n"

                output_dir = f"{tmp_dir}/outputs"

                # Execute the module in eager mode to calculate the reference outputs
                ref_output = module(*sample_inputs)
                if isinstance(ref_output, torch.Tensor):
                    ref_output = [ref_output]

                # Serialize the executorch model and save into a temporary file
                pte_fname = f"{tmp_dir}/openvino_executorch_test.pte"
                with open(pte_fname, "wb") as file:
                    exec_prog.write_to_file(file)

                # Save inputs into a temporary file
                self.generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
                self.make_output_dir(output_dir)

                # Start a subprocess to execute model with openvino_executor_runner
                cmd = [
                    f"{self.build_folder}/examples/openvino/openvino_executor_runner",
                    "--model_path",
                    pte_fname,
                    "--input_list_path",
                    f"{tmp_dir}/input_list.txt",
                    "--output_folder_path",
                    output_dir,
                ]

                env = dict(os.environ)
                proc = subprocess.run(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    cwd=tmp_dir,
                )

                stdout_str = proc.stdout.decode("utf-8")

                # Check if execution completed successfully
                self.assertIn("Model executed successfully.", stdout_str)

                # Read the outputs from the temporary files
                output_dir = f"{tmp_dir}/outputs"
                outputs = []

                for i, f in enumerate(sorted(os.listdir(output_dir))):
                    filename = os.path.join(output_dir, f)
                    output = np.fromfile(filename, dtype=ref_output[i].detach().numpy().dtype)
                    output = torch.from_numpy(output).reshape(ref_output[i].shape)
                    outputs.append(output)

                # Compare the outputs with the reference outputs
                self.assertTrue(len(ref_output) == len(outputs))
                for i in range(len(ref_output)):
                    self.assertTrue(
                        torch.allclose(
                            outputs[i], ref_output[i], atol=self.atol, rtol=self.rtol, equal_nan=True
                        ),
                        msg=f"ref_output:\n{ref_output[i]}\n\ntest_output:\n{outputs[i]}",
                    )

    def generate_inputs(self, dest_path: str, file_name: str, inputs=None, input_list=None):
        input_list_file = None
        input_files = []

        # Prepare input list
        if input_list is not None:
            input_list_file = f"{dest_path}/{file_name}"
            with open(input_list_file, "w") as f:
                f.write(input_list)
                f.flush()

        # Prepare input data
        if inputs is not None:
            for idx, data in enumerate(inputs):
                for i, d in enumerate(data):
                    file_name = f"{dest_path}/input_{idx}_{i}.raw"
                    d.detach().numpy().tofile(file_name)
                    input_files.append(file_name)

        return input_list_file, input_files

    def make_output_dir(self, path: str):
        if os.path.exists(path):
            for f in os.listdir(path):
                os.remove(os.path.join(path, f))
            os.removedirs(path)
        os.makedirs(path)
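
A possible way to run these op tests (a sketch, not part of this commit): build_folder defaults to an empty string, so it has to point at a build tree containing examples/openvino/openvino_executor_runner before the runner subprocess can start. The path below is a placeholder assumption.

# Sketch: discover and run the OpenVINO op tests with unittest.
import unittest

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest

# Placeholder build directory (assumption); must contain examples/openvino/openvino_executor_runner.
BaseOpenvinoOpTest.build_folder = "/path/to/executorch/cmake-out"
BaseOpenvinoOpTest.device = "CPU"

suite = unittest.defaultTestLoader.discover("backends/openvino/tests/ops", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)
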
Lines changed: 19 additions & 0 deletions

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
import torch


class TestAddOperator(BaseOpenvinoOpTest):

    def create_model(self):
        class Add(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.add(x, y)

        return Add()

    def test_add(self):
        module = self.create_model()
        sample_input = (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3))
        self.execute_layer_test(module, sample_input)
Lines changed: 25 additions & 0 deletions

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
import torch


class TestAddMMOperator(BaseOpenvinoOpTest):

    def create_model(self):
        class AddMM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.alpha = 1.0
                self.beta = 1.0

            def forward(self, x, y, z):
                return torch.addmm(x, y, z, alpha=self.alpha, beta=self.beta)

        return AddMM()

    def test_addmm(self):
        module = self.create_model()
        input_x = torch.randn(4, 4, dtype=torch.float32)
        input_y = torch.randn(4, 4, dtype=torch.float32)
        input_z = torch.randn(4, 4, dtype=torch.float32)
        sample_input = (input_x, input_y, input_z)
        self.execute_layer_test(module, sample_input)
Lines changed: 20 additions & 0 deletions

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
import torch


class TestArangeOperator(BaseOpenvinoOpTest):

    def create_model(self, x):
        class Arange(torch.nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self, y):
                return torch.arange(self.x, dtype=torch.float32) + y

        return Arange(x)

    def test_arange(self):
        module = self.create_model(5)
        sample_input = (torch.randn(5),)
        self.execute_layer_test(module, sample_input)
Lines changed: 51 additions & 0 deletions

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
import torch

op_params = [
    {'weights': True, 'bias': True, 'eps': 1.0},
    {'weights': True, 'bias': True, 'eps': 0.00005},
    {'weights': True, 'bias': True, 'eps': 0.5},
    {'weights': True, 'bias': True, 'eps': 0.042},
    {'weights': True, 'bias': False, 'eps': 1.0},
    {'weights': True, 'bias': False, 'eps': 0.00005},
    {'weights': True, 'bias': False, 'eps': 0.5},
    {'weights': True, 'bias': False, 'eps': 0.042},
    {'weights': False, 'bias': True, 'eps': 1.0},
    {'weights': False, 'bias': True, 'eps': 0.00005},
    {'weights': False, 'bias': True, 'eps': 0.5},
    {'weights': False, 'bias': True, 'eps': 0.042},
    {'weights': False, 'bias': False, 'eps': 1.0},
    {'weights': False, 'bias': False, 'eps': 0.00005},
    {'weights': False, 'bias': False, 'eps': 0.5},
    {'weights': False, 'bias': False, 'eps': 0.042},
]


class TestBatchNormOperator(BaseOpenvinoOpTest):

    def create_model(self, weights, bias, eps):

        class BatchNorm(torch.nn.Module):
            def __init__(self, weights=True, bias=True, eps=1e-05):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(6)) if weights else None
                self.bias = torch.nn.Parameter(torch.randn(6)) if bias else None
                self.running_mean = torch.randn(6)
                self.running_var = torch.randn(6)
                self.eps = eps

            def forward(self, x):
                return torch.nn.functional.batch_norm(
                    x, self.running_mean, self.running_var, self.weight, self.bias,
                    eps=self.eps, training=False
                )

        return BatchNorm(weights, bias, eps)

    def test_batch_norm(self):
        for params in op_params:
            with self.subTest(params=params):
                module = self.create_model(weights=params['weights'],
                                           bias=params['bias'],
                                           eps=params['eps'])

                sample_input = (torch.randn(20, 6, 10),)

                self.execute_layer_test(module, sample_input)
Lines changed: 105 additions & 0 deletions

from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
import torch

d2_params = [
    {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 3, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 3, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [1, 1],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [1, 1], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [3, 1],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [3, 1], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [1, 0],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [0, 1], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [1, 0], 'dilations': [1, 1],
     'groups': 3, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 1], 'dilations': [1, 1],
     'groups': 3, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [1, 0], 'dilations': [2, 2],
     'groups': 3, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [2, 2],
     'groups': 3, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'bias_shape': [1], 'pads': [1, 0],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [0, 0],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 2], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [0, 0], 'dilations': [1, 1],
     'groups': 1, 'output_padding': [0, 0], 'transposed': False},
    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [0, 0],
     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [1, 1],
     'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1], 'transposed': True},
]


class TestConvolutionOperator(BaseOpenvinoOpTest):

    def create_model(self, weights_shape, strides, pads, dilations, groups, bias, transposed,
                     output_padding=0, bias_shape=None, underscore=False):

        bias_dim = 0

        class Convolution(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(weights_shape))
                self.bias_shape = bias_shape
                if self.bias_shape is None:
                    self.bias_shape = weights_shape[bias_dim]
                self.bias = torch.nn.Parameter(torch.randn(self.bias_shape)) if bias else None
                self.strides = strides
                self.pads = pads
                self.dilations = dilations
                self.groups = groups
                self.transposed = transposed
                self.output_padding = output_padding
                if underscore:
                    self.forward = self.forward_

            def forward(self, x):
                return torch.convolution(
                    x, self.weight, self.bias, self.strides, self.pads, self.dilations,
                    self.transposed, self.output_padding, self.groups
                )

            def forward_(self, x):
                return torch._convolution(
                    x, self.weight, self.bias, self.strides, self.pads, self.dilations,
                    self.transposed, self.output_padding, self.groups,
                    False, False, False, False
                )

        return Convolution()

    def test_convolution(self):
        bias_underscore_config = [(False, False), (True, False)]
        for bias, underscore in bias_underscore_config:
            for params in d2_params:
                with self.subTest(params=params, bias=bias, underscore=underscore):
                    bias_shape = None
                    if 'bias_shape' in params:
                        bias_shape = params['bias_shape']
                    module = self.create_model(weights_shape=params['weights_shape'],
                                               strides=params['strides'],
                                               pads=params['pads'],
                                               dilations=params['dilations'],
                                               groups=params['groups'],
                                               output_padding=params['output_padding'],
                                               transposed=params['transposed'],
                                               bias_shape=bias_shape,
                                               bias=bias,
                                               underscore=underscore)
                    sample_input = (torch.randn(1, 3, 10, 10),)
                    self.execute_layer_test(module, sample_input)
