@@ -17,32 +17,91 @@
 - Transformer
 """
 
+from typing import Callable
+
 import torch
 from executorch.backends.arm.test.common import parametrize
 from executorch.backends.arm.test.tester.test_pipeline import (
     TosaPipelineFP,
     TosaPipelineINT,
 )
 
+
+def make_module_wrapper(
+    name: str, module_factory: Callable[[], torch.nn.Module]
+) -> torch.nn.Module:
+    class ModuleWrapper(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self._module = module_factory()
+
+        def forward(self, *args, **kwargs):
+            return self._module(*args, **kwargs)
+
+    ModuleWrapper.__name__ = name
+    ModuleWrapper.__qualname__ = name
+    return ModuleWrapper()
+
+
 example_input = torch.rand(1, 6, 16, 16)
 
 module_tests = [
-    (torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)),
-    (torch.nn.LeakyReLU(), (example_input,)),
-    (torch.nn.BatchNorm1d(16), (torch.rand(6, 16, 16),)),
-    (torch.nn.AdaptiveAvgPool2d((12, 12)), (example_input,)),
-    (torch.nn.ConvTranspose2d(6, 3, 2), (example_input,)),
-    (torch.nn.GRU(10, 20, 2), (torch.randn(5, 3, 10), torch.randn(2, 3, 20))),
-    (torch.nn.GroupNorm(2, 6), (example_input,)),
-    (torch.nn.InstanceNorm2d(16), (example_input,)),
-    (torch.nn.PReLU(), (example_input,)),
     (
-        torch.nn.Transformer(
-            d_model=64,
-            nhead=1,
-            num_encoder_layers=1,
-            num_decoder_layers=1,
-            dtype=torch.float32,
+        make_module_wrapper(
+            "EmbeddingModule",
+            lambda: torch.nn.Embedding(10, 10),
+        ),
+        (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),),
+    ),
+    (
+        make_module_wrapper("LeakyReLUModule", torch.nn.LeakyReLU),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("BatchNorm1dModule", lambda: torch.nn.BatchNorm1d(16)),
+        (torch.rand(6, 16, 16),),
+    ),
+    (
+        make_module_wrapper(
+            "AdaptiveAvgPool2dModule",
+            lambda: torch.nn.AdaptiveAvgPool2d((12, 12)),
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "ConvTranspose2dModule", lambda: torch.nn.ConvTranspose2d(6, 3, 2)
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("GRUModule", lambda: torch.nn.GRU(10, 20, 2)),
+        (torch.randn(5, 3, 10), torch.randn(2, 3, 20)),
+    ),
+    (
+        make_module_wrapper("GroupNormModule", lambda: torch.nn.GroupNorm(2, 6)),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "InstanceNorm2dModule", lambda: torch.nn.InstanceNorm2d(16)
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("PReLUModule", torch.nn.PReLU),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "TransformerModule",
+            lambda: torch.nn.Transformer(
+                d_model=64,
+                nhead=1,
+                num_encoder_layers=1,
+                num_decoder_layers=1,
+                dtype=torch.float32,
+            ),
         ),
         (torch.rand((10, 32, 64)), torch.rand((20, 32, 64))),
     ),
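(Reviewer note, not part of the diff.) The new `make_module_wrapper` helper defers module construction to a factory and renames the wrapper class per test case, so anything keyed on the module's class name sees a distinct, readable ID instead of the raw `torch.nn` name. A minimal self-contained sketch of the same pattern, exercising a hypothetical `LeakyReLUModule` case:

```python
from typing import Callable

import torch


def make_module_wrapper(
    name: str, module_factory: Callable[[], torch.nn.Module]
) -> torch.nn.Module:
    # Build the inner module lazily and rename the wrapper class so that
    # type(instance).__name__ reports `name` rather than "ModuleWrapper".
    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._module = module_factory()

        def forward(self, *args, **kwargs):
            return self._module(*args, **kwargs)

    ModuleWrapper.__name__ = name
    ModuleWrapper.__qualname__ = name
    return ModuleWrapper()


relu = make_module_wrapper("LeakyReLUModule", torch.nn.LeakyReLU)
assert type(relu).__name__ == "LeakyReLUModule"
out = relu(torch.rand(1, 6, 16, 16))  # forwards to the wrapped LeakyReLU
```

Without the wrapper, every `torch.nn.LeakyReLU()` instance reports the class name "LeakyReLU"; the rename keeps generated test IDs (and the xfails keys in the next hunk) unambiguous, and the factory gives each parametrized case a freshly constructed module.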
@@ -78,9 +137,9 @@ def test_nn_Modules_FP(test_data):
     "test_data",
     test_parameters,
     xfails={
-        "GRU": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
-        "PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
-        "Transformer": "AssertionError: Output 0 does not match reference output.",
+        "GRUModule": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
+        "PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
+        "TransformerModule": "AssertionError: Output 0 does not match reference output.",
     },
 )
 def test_nn_Modules_INT(test_data):
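(Reviewer note.) The xfails keys are renamed in lockstep with the wrapper names, since `parametrize` matches expected failures against the generated test-case IDs, which now derive from `type(module).__name__`. A hypothetical pytest-flavored sketch of that kind of ID-keyed matching; the real `parametrize` helper in `executorch.backends.arm.test.common` may work differently:

```python
import pytest


def params_with_xfails(module_tests, xfails):
    # Turn (module, inputs) pairs into pytest params whose ID is the
    # module's class name; IDs found in `xfails` become expected failures.
    for module, inputs in module_tests:
        test_id = type(module).__name__
        marks = [pytest.mark.xfail(reason=xfails[test_id])] if test_id in xfails else []
        yield pytest.param(module, inputs, id=test_id, marks=marks)
```

With the wrapped modules, `test_id` becomes "GRUModule", "PReLUModule", and so on, which is exactly the spelling the updated xfails dict uses.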