[MXNET-978] n-th order gradient test support. #15611
Changes from all commits: 1b5d96a, 1f74614, 4e5c4cc, 5523d4a, daa77f1, f7cd885, dccc2e8, c1c14d2, 82174a4
@@ -31,10 +31,16 @@ def sin(x):
     def grad_grad_op(x):
         return -nd.sin(x)
 
+    def grad_grad_grad_op(x):
+        return -nd.cos(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sin, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sin,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
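For context on the new third-order assertion: the chain exercised here is d/dx sin(x) = cos(x), d2/dx2 sin(x) = -sin(x), d3/dx3 sin(x) = -cos(x), which is why grad_grad_op returns -nd.sin(x) and grad_grad_grad_op returns -nd.cos(x). A minimal standalone sketch of chaining autograd.grad three times, assuming an MXNet build where higher-order gradients of sin are registered (this snippet is an editor's illustration, not part of the diff):

from mxnet import autograd, nd

x = nd.array([0.5, 1.0, 1.5])
x.attach_grad()
with autograd.record():
    y = nd.sin(x)
    # 1st derivative: cos(x); keep it on the graph so it can be differentiated again.
    dy = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    # 2nd derivative: -sin(x).
    d2y = autograd.grad(dy, x, create_graph=True, retain_graph=True)[0]
    # 3rd derivative: -cos(x).
    d3y = autograd.grad(d2y, x, create_graph=False, retain_graph=False)[0]

print(d2y.asnumpy(), (-nd.sin(x)).asnumpy())  # should match elementwise
print(d3y.asnumpy(), (-nd.cos(x)).asnumpy())  # should match elementwise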
@@ -45,10 +51,16 @@ def cos(x):
     def grad_grad_op(x):
         return -nd.cos(x)
 
+    def grad_grad_grad_op(x):
+        return nd.sin(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, cos, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, cos,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
@@ -150,13 +162,18 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, log, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, log, [grad_op, grad_grad_op], [1, 2])
 
 
 @with_seed()
@@ -259,6 +276,9 @@ def grad_grad_op(x):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sigmoid, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sigmoid, [grad_op, grad_grad_op], [1, 2])
+        check_nth_order_unary(array, sigmoid, grad_grad_op, 2)
 
 
 @with_seed()
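The sigmoid hunk does not show the grad_op/grad_grad_op helpers it calls; with s = sigmoid(x), the closed forms they need to compute are s' = s*(1 - s) and s'' = s*(1 - s)*(1 - 2*s). A hypothetical sketch of such helpers (only the names appear in the diff; these exact bodies are an assumption):

from mxnet import nd

def sigmoid(x):
    return nd.sigmoid(x)

def grad_op(x):
    # First derivative: s * (1 - s), with s = sigmoid(x).
    s = nd.sigmoid(x)
    return s * (1 - s)

def grad_grad_op(x):
    # Second derivative: s * (1 - s) * (1 - 2s).
    s = nd.sigmoid(x)
    return s * (1 - s) * (1 - 2 * s)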
@@ -302,28 +322,77 @@ def grad_grad_op(x):
 
 
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
+    check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)
+
+
+def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
+    """Assert n-th order autograd gradient against the expected gradient.
+
+    Multiple orders of gradients can be checked by passing a list of
+    functions computing the gradient of the corresponding order, together
+    with the matching list of orders.
+
+    Note
+    ----
+    1. Orders should always be monotonically increasing.
+    2. Elements of grad_ops should correspond to elements of orders,
+       i.e. grad_ops = [grad_op, grad_grad_grad_op] should be passed with
+       orders = [1, 3].
+
+    Parameters
+    ----------
+    x : mxnet.NDArray
+        Input Array.
+    op : Callable
+        Operation to perform on the input array.
+    grad_ops : Callable or List of Callable
+        Function(s) to compute and assert the gradient of the given order.
+    orders : int or List of int
+        Order(s) at which to assert expected and computed gradients.
+
+    Returns
+    -------
+    None
+
+    """
+    if isinstance(orders, int):
+        orders = [orders]
+        grad_ops = [grad_ops]
+
+    assert all(i < j for i, j in zip(orders[0:-1], orders[1:])), \
+        "orders should be monotonically increasing"
+    assert len(set(orders)) == len(orders), \
+        "orders should have unique elements"
+    highest_order = max(orders)
+
     x = nd.array(x)
-    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
 
-    # Manual head_grads.
-    y_grad = nd.random.normal(shape=x.shape)
-    head_grad_grads = nd.random.normal(shape=x.shape)
+    expected_grads = [grad_op(x) for grad_op in grad_ops]
+    computed_grads = []
+    head_grads = []
 
     # Perform compute.
     with autograd.record():
         y = op(x)
-        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
-                               create_graph=True, retain_graph=True)[0]
-    x_grad.backward(head_grad_grads)
-
-    # Compute expected values.
-    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
-        y_grad.asnumpy()
-
-    # Validate the gradients.
-    assert_almost_equal(expected_grad_grad,
-                        x.grad.asnumpy(), rtol=rtol, atol=atol)
+        for current_order in range(1, highest_order+1):
+            head_grad = nd.random.normal(shape=x.shape)
+            y = autograd.grad(heads=y, variables=x, head_grads=head_grad,
+                              create_graph=True, retain_graph=True)[0]
+            if current_order in orders:
Review comment: If current_order is not in orders, might we have a problem zipping? Is there a case where you would want 1st and 3rd order but not second?

Reply: I don't think there would be an issue in that case (will confirm it later though). The main thing is that elements in grad_ops should correspond to the elements in orders.

Reply: Have confirmed that it works.
+                computed_grads.append(y)
+            head_grads.append(head_grad)
+
+    # Validate all the gradients.
+    for order, grad, computed_grad in \
+            zip(orders, expected_grads, computed_grads):
+        # Compute expected values.
+        expected_grad = grad.asnumpy()
+        for head_grad in head_grads[:order]:
+            expected_grad *= head_grad.asnumpy()
+
+        assert_almost_equal(
+            expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)
 
 
 if __name__ == '__main__':
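To summarize the helper's contract from the docstring above: grad_ops and orders are parallel lists, orders must be strictly increasing and unique, and a single callable with an int order is also accepted. A usage sketch mirroring the calls in the tests; the [1, 3] call illustrates the non-contiguous case from the review thread above and is the editor's illustration, not code from the PR:

# Assumes the test module's existing imports and helpers:
# nd, rand_shape_nd, random_arrays, check_nth_order_unary.
array = random_arrays(rand_shape_nd(2))

# Single order: check only the 2nd derivative of sin.
check_nth_order_unary(array, nd.sin, lambda x: -nd.sin(x), 2)

# Parallel lists: 2nd and 3rd derivatives of sin, as in test_sin above.
check_nth_order_unary(array, nd.sin,
                      [lambda x: -nd.sin(x), lambda x: -nd.cos(x)], [2, 3])

# Non-contiguous orders (1st and 3rd, skipping the 2nd); the thread above
# reports this pattern works.
check_nth_order_unary(array, nd.sin,
                      [nd.cos, lambda x: -nd.cos(x)], [1, 3])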
Review comment: Better name instead of mutating y?

Reply: One option I can think of is, at the start of the for loop, to have computed_grad = y (which is deceiving) and to replace y with computed_grad in the for loop. Up for suggestions.
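A hedged sketch of the option floated in the reply above (not code from this PR): bind the running head gradient to a separately named variable instead of mutating y, at the cost of the admittedly deceiving computed_grad = y initialisation.

    # Fragment of check_nth_order_unary's compute loop, renamed per the suggestion.
    with autograd.record():
        y = op(x)
        computed_grad = y
        for current_order in range(1, highest_order + 1):
            head_grad = nd.random.normal(shape=x.shape)
            computed_grad = autograd.grad(heads=computed_grad, variables=x,
                                          head_grads=head_grad,
                                          create_graph=True,
                                          retain_graph=True)[0]
            if current_order in orders:
                computed_grads.append(computed_grad)
            head_grads.append(head_grad)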