-
Notifications
You must be signed in to change notification settings - Fork 16
/
test_trace.py
63 lines (53 loc) · 2.27 KB
/
test_trace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations
import unittest
import paddle
import paddlefx
class TestFx(unittest.TestCase):
    """Check that ``paddlefx.symbolic_trace`` preserves model outputs.

    Every model listed in :meth:`setUp` is traced with
    ``paddlefx.symbolic_trace``; the original and traced models are run on the
    same example input and each of their outputs must agree under
    ``paddle.allclose``.
    """

    def setUp(self):
        """Build the (model, example input) pairs exercised by ``test_trace``."""
        super().setUp()
        # Each entry pairs a freshly constructed paddle.vision model with a
        # random input of shape [2, 3, 224, 224] (inception_v3 gets
        # [2, 3, 299, 299] instead).
        self.models_to_track = [
            (paddle.vision.models.resnet18(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.alexnet(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.densenet121(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.googlenet(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.inception_v3(), paddle.randn([2, 3, 299, 299])),
            (paddle.vision.models.mobilenet_v2(), paddle.randn([2, 3, 224, 224])),
            (
                paddle.vision.models.shufflenet_v2_swish(),
                paddle.randn([2, 3, 224, 224]),
            ),
            (paddle.vision.models.squeezenet1_0(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.vgg11(), paddle.randn([2, 3, 224, 224])),
            (paddle.vision.models.wide_resnet101_2(), paddle.randn([2, 3, 224, 224])),
        ]

    @staticmethod
    def _as_output_list(output):
        """Return a model's forward result as a list of outputs.

        Some nets, e.g. googlenet, return a list/tuple of tensors rather than a
        single tensor. A list or tuple is converted to a new list; any other
        value is wrapped in a one-element list.
        """
        if isinstance(output, (list, tuple)):
            return list(output)
        return [output]

    def test_trace(self):
        """Traced models must reproduce every output of the original model."""
        for model, input_example in self.models_to_track:
            name = type(model).__name__
            # subTest reports each model separately, so one failing net does
            # not stop the remaining models from being checked.
            with self.subTest(model=name):
                print(f"tracing model: {name}")
                traced_model = paddlefx.symbolic_trace(model)

                # Re-seed before each forward pass so any random ops inside
                # the model (presumably dropout) draw the same values in both
                # runs.
                paddle.seed(1234)
                orig_outputs = self._as_output_list(model(input_example))
                paddle.seed(1234)
                traced_outputs = self._as_output_list(
                    traced_model(input_example)
                )

                self.assertEqual(
                    len(orig_outputs),
                    len(traced_outputs),
                    f"model: {name} failed",
                )
                # Lengths are equal at this point, so zip pairs every output.
                for i, (orig, traced) in enumerate(
                    zip(orig_outputs, traced_outputs)
                ):
                    self.assertTrue(
                        bool(paddle.allclose(orig, traced)),
                        f"model: {name} failed (output {i})",
                    )