Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] n-th order gradient test support. #15611

Merged
94 changes: 79 additions & 15 deletions tests/python/unittest/test_higher_order_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ def sin(x):
def grad_grad_op(x):
return -nd.sin(x)

def grad_grad_grad_op(x):
return -nd.cos(x)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, sin, grad_grad_op)
# TODO(kshitij12345): Remove
check_nth_order_unary(array, sin,
[grad_grad_op, grad_grad_grad_op], [2, 3])


@with_seed()
Expand All @@ -45,10 +51,16 @@ def cos(x):
def grad_grad_op(x):
return -nd.cos(x)

def grad_grad_grad_op(x):
return nd.sin(x)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, cos, grad_grad_op)
# TODO(kshitij12345): Remove
check_nth_order_unary(array, cos,
[grad_grad_op, grad_grad_grad_op], [2, 3])


@with_seed()
Expand Down Expand Up @@ -150,13 +162,18 @@ def test_log():
def log(x):
return nd.log(x)

def grad_op(x):
return 1/x

def grad_grad_op(x):
return -1/(x**2)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, log, grad_grad_op)
# TODO(kshitij12345): Remove
check_nth_order_unary(array, log, [grad_op, grad_grad_op], [1, 2])


@with_seed()
Expand Down Expand Up @@ -259,6 +276,9 @@ def grad_grad_op(x):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, sigmoid, grad_grad_op)
# TODO(kshitij12345): Remove
check_nth_order_unary(array, sigmoid, [grad_op, grad_grad_op], [1, 2])
check_nth_order_unary(array, sigmoid, grad_grad_op, 2)


@with_seed()
Expand Down Expand Up @@ -302,28 +322,72 @@ def grad_grad_op(x):


def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)


def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
"""Assert n-th order autograd gradient against expected gradient.

Multiple order of gradients can be checked by passing list of
function computing the particular order gradient and passing the
corresponding list of order.

Note
----
1. Orders should always be monotonically increasing.
2. Elements of grads_ops should correspond to elements of orders
i.e. grads_op = [grad_op, grad_grad_grad_op] should be passed with
orders = [1, 3]

Parameters
----------
x : mxnet.NDArray
Input Array.
op : Callable
Operation to perform on Input Array.
grad_ops : Callable or List of Callable
Function to compute and assert gradient of given order.
orders : int or List of int
Order/s to assert expected and computed gradients.

Returns
-------
None

"""
if isinstance(orders, int):
orders = [orders]
grad_ops = [grad_ops]

x = nd.array(x)
grad_grad_x = grad_grad_op(x)
x.attach_grad()

# Manual head_grads.
y_grad = nd.random.normal(shape=x.shape)
head_grad_grads = nd.random.normal(shape=x.shape)
order = max(orders)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If orders is monotonically increasing, should this just be orders[-1]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That will work as well. But I felt max(orders) stated the intent more clearly.

expected_grads = [grad_op(x) for grad_op in grad_ops]
computed_grads = []
head_grads = []

# Perform compute.
with autograd.record():
y = op(x)
x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
create_graph=True, retain_graph=True)[0]
x_grad.backward(head_grad_grads)

# Compute expected values.
expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
y_grad.asnumpy()

# Validate the gradients.
assert_almost_equal(expected_grad_grad,
x.grad.asnumpy(), rtol=rtol, atol=atol)
for current_order in range(1, order+1):
head_grad = nd.random.normal(shape=x.shape)
y = autograd.grad(heads=y, variables=x, head_grads=head_grad,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better name instead of mutating y?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One option I can think of is, at the start of for loop we'll have computed_grad = y (which is deceiving) and replace y by computed_grad in the for loop.
Up for suggestions.

create_graph=True, retain_graph=True)[0]
if current_order in orders:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If current_order is not in orders we might have problem zipping? Is there a case where you wou want 1st and 3rd order but not second?

Copy link
Contributor Author

@kshitij12345 kshitij12345 Jul 23, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there would be an issue in that case (will confirm it later though).
If current_order is not in orders case is checked here,
https://github.com/apache/incubator-mxnet/blob/1f74614391e182e299e2fdcce1036515c4e5fb4f/tests/python/unittest/test_higher_order_grad.py#L41-L42
where first order is not asserted (as per the arguments).

The main thing is that elements in orders should be monotonically increasing and they should correspond to elements of the grad_ops.
We can also use a dictionary {order: corresponding grad op, ... } which removes the above requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have confirmed, that

check_nth_order_unary(array, sin, [grad_op, grad_grad_grad_op], [1, 3])

following works.

computed_grads.append(y)
head_grads.append(head_grad)

# Validate all the gradients.
for order, grad, computed_grad in \
zip(orders, expected_grads, computed_grads):
# Compute expected values.
expected_grad = grad.asnumpy()
for head_grad in head_grads[:order]:
expected_grad *= head_grad.asnumpy()

assert_almost_equal(
expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)


if __name__ == '__main__':
Expand Down