diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index a758775a09ba..9c758c8467e3 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -31,10 +31,16 @@ def sin(x):
     def grad_grad_op(x):
         return -nd.sin(x)
 
+    def grad_grad_grad_op(x):
+        return -nd.cos(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sin, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sin,
+                               [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
@@ -45,10 +51,16 @@ def cos(x):
     def grad_grad_op(x):
         return -nd.cos(x)
 
+    def grad_grad_grad_op(x):
+        return nd.sin(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, cos, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, cos,
+                               [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
@@ -150,6 +162,9 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
@@ -157,6 +172,8 @@ def grad_grad_op(x):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, log, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, log, [grad_op, grad_grad_op], [1, 2])
 
 
 @with_seed()
@@ -259,6 +276,9 @@ def grad_grad_op(x):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sigmoid, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sigmoid, [grad_op, grad_grad_op], [1, 2])
+        check_nth_order_unary(array, sigmoid, grad_grad_op, 2)
 
 
 @with_seed()
@@ -302,28 +322,77 @@ def grad_grad_op(x):
 
 
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
+    check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)
+
+
+def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
+    """Assert n-th order autograd gradient against expected gradient.
+
+    Multiple orders of gradients can be checked by passing a list of
+    functions that compute the gradient of a particular order along with
+    the corresponding list of orders.
+
+    Note
+    ----
+    1. Orders should always be monotonically increasing.
+    2. Elements of grad_ops should correspond to elements of orders,
+       i.e. grad_ops = [grad_op, grad_grad_grad_op] should be passed with
+       orders = [1, 3].
+
+    Parameters
+    ----------
+    x : mxnet.NDArray
+        Input Array.
+    op : Callable
+        Operation to perform on Input Array.
+    grad_ops : Callable or List of Callable
+        Function(s) to compute and assert the gradient of the given order.
+    orders : int or List of int
+        Order(s) of the expected and computed gradients to assert.
+
+    Returns
+    -------
+    None
+
+    """
+    if isinstance(orders, int):
+        orders = [orders]
+        grad_ops = [grad_ops]
+
+    assert all(i < j for i, j in zip(orders[0:-1], orders[1:])), \
+        "orders should be monotonically increasing"
+    assert len(set(orders)) == len(orders), \
+        "orders should have unique elements"
+    highest_order = max(orders)
+
     x = nd.array(x)
-    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
 
-    # Manual head_grads.
-    y_grad = nd.random.normal(shape=x.shape)
-    head_grad_grads = nd.random.normal(shape=x.shape)
+    expected_grads = [grad_op(x) for grad_op in grad_ops]
+    computed_grads = []
+    head_grads = []
 
     # Perform compute.
     with autograd.record():
         y = op(x)
-        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
-                               create_graph=True, retain_graph=True)[0]
-        x_grad.backward(head_grad_grads)
-
-    # Compute expected values.
-    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
-        y_grad.asnumpy()
-
-    # Validate the gradients.
-    assert_almost_equal(expected_grad_grad,
-                        x.grad.asnumpy(), rtol=rtol, atol=atol)
+        for current_order in range(1, highest_order+1):
+            head_grad = nd.random.normal(shape=x.shape)
+            y = autograd.grad(heads=y, variables=x, head_grads=head_grad,
+                              create_graph=True, retain_graph=True)[0]
+            if current_order in orders:
+                computed_grads.append(y)
+            head_grads.append(head_grad)
+
+    # Validate all the gradients.
+    for order, grad, computed_grad in \
+            zip(orders, expected_grads, computed_grads):
+        # Compute expected values.
+        expected_grad = grad.asnumpy()
+        for head_grad in head_grads[:order]:
+            expected_grad *= head_grad.asnumpy()
+
+        assert_almost_equal(
+            expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)
 
 
 if __name__ == '__main__':
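
For context on the validation step: the expected value is the analytic n-th derivative multiplied by every `head_grad` because each `autograd.grad` call computes a vector-Jacobian product, so for an elementwise unary op the n-th chained result is `f^(n)(x) * w_1 * ... * w_n`. Below is a minimal sketch of that identity for `sin` up to second order, not part of the patch; the input values are illustrative and it assumes the higher-order backward for `sin` (which these tests exercise) is available.

```python
from mxnet import autograd, nd

x = nd.array([0.1, 0.5, 1.0])
x.attach_grad()
w1 = nd.random.normal(shape=x.shape)   # head grad for the 1st-order call
w2 = nd.random.normal(shape=x.shape)   # head grad for the 2nd-order call

with autograd.record():
    y = nd.sin(x)
    # 1st-order vector-Jacobian product: cos(x) * w1
    g1 = autograd.grad(heads=y, variables=x, head_grads=w1,
                       create_graph=True, retain_graph=True)[0]
    # 2nd-order: d/dx(cos(x) * w1) contracted with w2 = -sin(x) * w1 * w2
    g2 = autograd.grad(heads=g1, variables=x, head_grads=w2,
                       create_graph=True, retain_graph=True)[0]

expected = (-nd.sin(x) * w1 * w2).asnumpy()
print(abs(g2.asnumpy() - expected).max())  # ~0, up to float tolerance
```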