forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_onnxscript_runtime.py
127 lines (103 loc) · 4.41 KB
/
test_onnxscript_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
from typing import List
import onnx_test_common
import onnxscript
import torch
from onnxscript.onnx_types import FLOAT
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
    """Export models whose ops are implemented by onnxscript functions.

    Each test registers a custom symbolic that emits an onnxscript function via
    ``g.onnxscript_op`` and then calls ``run_test`` on the model. Every
    registration is undone with ``addCleanup`` so the process-wide symbolic
    registry is not left modified for tests that run afterwards.
    """

    # ONNX local functions are only available from opset 15 onwards. Also,
    # onnxscript requires the user to fix the opset used inside a local
    # function, so the export opset is pinned to 15.
    opset_version = 15

    def test_selu_from_onnxscript_example(self):
        """Export ``torch.nn.SELU`` through an onnxscript implementation of SELU."""
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model = torch.nn.SELU()

        from onnxscript.onnx_opset import opset15 as op

        # TODO(titaiwang): make an official domain for onnxscript usage
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # SELU(x) = scale * x                     if x > 0
            #         = scale * alpha * (exp(x) - 1)  otherwise
            # TODO: onnx/ort doesn't support default attribute values yet, so
            # alpha and scale are hard-coded here; make them attributes once
            # default values are supported. The values are torch.nn.SELU's
            # documented constants, so this matches the PyTorch model exactly.
            alpha = 1.6732632423543772  # auto wrapped as Constants
            gamma = 1.0507009873554805
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # Emit a call to the Selu local function; the output keeps X's type.
            return g.onnxscript_op(Selu, X).setType(X.type())

        torch.onnx.register_custom_op_symbolic(
            symbolic_name="aten::selu",
            symbolic_fn=custom_selu,
            opset_version=self.opset_version,
        )
        # The registration is global; remove it after this test so later tests
        # are not exported with this custom symbolic.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::selu",
            self.opset_version,
        )
        self.run_test(model, x)

    def test_layer_norm(self):
        """Export stacked ``LayerNorm`` layers using an onnxscript layer_norm."""
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                # Layer i uses eps=i (0, 1, 2, ...), so each LayerNorm in the
                # stack exports with a different eps attribute.
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        model = M(3)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)

        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            # (X - mean) / sqrt(var + eps) over `axes`, then scale and shift.
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            # Cast to weight's element type before multiplying; the export
            # runs into a type issue if this CastLike is missing.
            normalizedw = op.CastLike(normalized, weight)
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g: jit_utils.GraphContext,
            input,
            normalized_shape,
            weight,
            bias,
            eps,
            cudnn_enable,
        ):
            # Normalize over the trailing len(normalized_shape) dimensions,
            # i.e. axes [-n, ..., -1].
            # TODO: move the comprehension into local function once it's supported by onnxscript
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        torch.onnx.register_custom_op_symbolic(
            symbolic_name="aten::layer_norm",
            symbolic_fn=custom_layer_norm,
            opset_version=self.opset_version,
        )
        # Undo the global registration so later tests use the default
        # aten::layer_norm symbolic again.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::layer_norm",
            self.opset_version,
        )
        self.run_test(model, (x, y, z))
if __name__ == "__main__":
    # Executing this file directly hands control to PyTorch's shared test runner.
    common_utils.run_tests()