This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
test_higher_order_grad.py
156 lines (118 loc) · 4.14 KB
/
test_higher_order_grad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed
@with_seed()
def test_sin():
    """Check the second-order gradient of sin: d^2/dx^2 sin(x) = -sin(x)."""
    def forward(data):
        return nd.sin(data)

    def second_derivative(data):
        return -nd.sin(data)

    # Exercise inputs of rank 1 through 4.
    for ndim in range(1, 5):
        check_second_order_unary(random_arrays(rand_shape_nd(ndim)),
                                 forward, second_derivative)
@with_seed()
def test_cos():
    """Check the second-order gradient of cos: d^2/dx^2 cos(x) = -cos(x)."""
    def cosine(inp):
        return nd.cos(inp)

    def expected_second_grad(inp):
        return -nd.cos(inp)

    for rank in range(1, 5):
        data = random_arrays(rand_shape_nd(rank))
        check_second_order_unary(data, cosine, expected_second_grad)
@with_seed()
def test_relu():
    """Check the second-order gradient of relu, which is zero wherever defined."""
    def relu_fn(t):
        return nd.relu(t)

    def zero_curvature(t):
        # relu is piecewise linear, so its second derivative vanishes (x != 0).
        return nd.zeros_like(t)

    for num_dims in range(1, 5):
        sample_shape = rand_shape_nd(num_dims)
        sample = random_arrays(sample_shape)
        check_second_order_unary(sample, relu_fn, zero_curvature)
@with_seed()
def test_log():
    """Check the second-order gradient of ln: d^2/dx^2 ln(x) = -1 / x^2."""
    def natural_log(v):
        return nd.log(v)

    def log_second_grad(v):
        return -1 / (v ** 2)

    # NOTE(review): random_arrays may yield non-positive values, where log
    # itself is undefined; only the gradient formula is compared here.
    # Confirm that this is intended.
    for ndim in range(1, 5):
        inputs = random_arrays(rand_shape_nd(ndim))
        check_second_order_unary(inputs, natural_log, log_second_grad)
@with_seed()
def test_log2():
    """Check the second-order gradient of log2: -1 / (x^2 * ln 2)."""
    def base2_log(v):
        return nd.log2(v)

    def base2_log_second_grad(v):
        return -1 / ((v ** 2) * math.log(2))

    for ndim in range(1, 5):
        inputs = random_arrays(rand_shape_nd(ndim))
        check_second_order_unary(inputs, base2_log, base2_log_second_grad)
@with_seed()
def test_log10():
    """Check the second-order gradient of log10: -1 / (x^2 * ln 10)."""
    def base10_log(v):
        return nd.log10(v)

    def base10_log_second_grad(v):
        return -1 / ((v ** 2) * math.log(10))

    for rank in range(1, 5):
        shp = rand_shape_nd(rank)
        check_second_order_unary(random_arrays(shp), base10_log,
                                 base10_log_second_grad)
@with_seed()
def test_sigmoid():
    """Check the second-order gradient of sigmoid using random head gradients.

    With f(x) = sigmoid(x):
        f'(x)  = f(x) * (1 - f(x))
        f''(x) = f'(x) * (1 - 2 * f(x))
    """
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_grad_op(x):
        # Second derivative of sigmoid. The previous expression,
        # f(x)^2 * (1 - f(x)), is not f''(x).
        s = sigmoid(x)
        return s * (1 - s) * (1 - 2 * s)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_sigmoid(array, sigmoid, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
    """Assert that the second-order gradient of ``op`` at ``x`` matches ``grad_grad_op(x)``.

    Both backward passes use implicit all-ones head gradients, so the
    result is compared directly against the analytic second derivative.

    :param x: input values, converted with ``nd.array``.
    :param op: unary NDArray function under test.
    :param grad_grad_op: analytic second derivative of ``op``.
    """
    x = nd.array(x)
    expect_grad_grad = grad_grad_op(x)
    x.attach_grad()
    with autograd.record():
        y = op(x)
        # create_graph=True records the first backward pass so it can be
        # differentiated again below.
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()
    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())


def check_sigmoid(x, op, grad_grad_op):
    """Assert second-order gradient correctness of ``op`` with random head gradients.

    The first backward pass uses head gradient ``y_grad`` and the second uses
    ``head_grad_grad``. Since ``x_grad = y_grad * op'(x)`` with ``y_grad``
    held constant, the expected result is
    ``head_grad_grad * y_grad * op''(x)``.

    :param x: input values, converted with ``nd.array``.
    :param op: unary NDArray function under test (not limited to sigmoid).
    :param grad_grad_op: analytic second derivative of ``op``.
    """
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    x.attach_grad()
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grad = nd.random.normal(shape=x.shape)
    with autograd.record():
        y = op(x)
        x_grad = autograd.grad(y, x, y_grad,
                               create_graph=True, retain_graph=True)[0]
    x_grad_grad = autograd.grad(x_grad, [x], head_grad_grad,
                                create_graph=False, retain_graph=True)[0]
    # The chain rule carries the first head gradient through the second pass.
    expected_grad_grad = grad_grad_x * y_grad * head_grad_grad
    assert_almost_equal(expected_grad_grad.asnumpy(), x_grad_grad.asnumpy())
if __name__ == '__main__':
    # Allow running this file directly; nose collects the test_* functions.
    from nose import runmodule
    runmodule()