Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refine autograd docs
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Jun 3, 2019
1 parent 3ebd873 commit c073cde
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
38 changes: 38 additions & 0 deletions docs/api/python/autograd/autograd.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ and do some computation. Finally, call `backward()` on the result:
<NDArray 4 @cpu(0)>
```

Gradient recording is enabled during the scope of the `with mx.autograd.record():` statement, then
disabled when we go out of that scope.

It can be also set manually by executing `mx.autograd.set_recording(True)`, and turning it off after
we no longer want to record operations with `mx.autograd.set_recording(False)`.


## Train mode and Predict Mode

Expand All @@ -76,6 +82,38 @@ Detailed tutorials are available in Part 1 of
[the MXNet gluon book](http://gluon.mxnet.io/).


# Higher order gradient

Some operators support higher order gradients. Meaning that you calculate the gradient of the
gradient. For this the operator's backward must be as well differentiable. Some operators support
differentiating multiple times, and others two, most just once.

For calculating higher order gradients, we can use the `mx.autograd.grad` function while recording
and then call backward, or call `mx.autograd.grad` two times. If we do the later is important that
the first call uses `create_graph=True` and `retain_graph=True` and the second call uses
`create_graph=False` and `retain_graph=True`. Otherwise we will not get the results that we want. If
we would be to recreate the graph in the second call, we would end up with a graph of just the
backward nodes, not the full initial graph that includes the forward nodes.

The idiom to calculate higher order gradients is the following:

```python
import mxnet autograd as ag
with ag.record():
y = f(x)
y_grad = ag.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad_grad = ag.grad(y_grad, x, create_graph=False, retain_graph=True)[0]
```

or

```python
import mxnet autograd as ag
with ag.record():
y = f(x)
y_grad = ag.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad_grad = y_grad.backward()
```



Expand Down
3 changes: 3 additions & 0 deletions python/mxnet/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def predict_mode():
def mark_variables(variables, gradients, grad_reqs='write'):
"""Mark NDArrays as variables to compute gradient for autograd.
This is equivalent to the function .attach_grad() in a variable, but with this
call we can set the gradient to any value.
Parameters
----------
variables: NDArray or list of NDArray
Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,8 @@ def attach_grad(self, grad_req='write', stype=None):
"""Attach a gradient buffer to this NDArray, so that `backward`
can compute gradient with respect to it.
The gradient is initialized to zeros.
Parameters
----------
grad_req : {'write', 'add', 'null'}
Expand Down

0 comments on commit c073cde

Please sign in to comment.