Commit
Merge pull request #5 from brunorigal/add_jax
Add jax
Showing 6 changed files with 160 additions and 74 deletions.
@@ -1,7 +1,7 @@
 # autograd-minimize

 autograd-minimize is a wrapper around the minimize routine of scipy which uses the autograd capacities of
-tensorflow or pytorch to compute automatically the gradients,
+jax, tensorflow or pytorch to compute automatically the gradients,
 hessian vector products and hessians.

 It also accepts functions of more than one variables as input.
@@ -31,6 +31,7 @@ But you can also use pytorch:
 ```
 import torch
 from autograd_minimize import minimize
+import numpy as np
 def rosen_torch(x):
     return (100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0).sum()
@@ -40,6 +41,17 @@ print(res.x)
 >>> array([0.99999912, 0.99999824])
 ```

+Or jax:
+```
+import numpy as np
+from autograd_minimize import minimize
+rosen_jax=lambda x: (100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0).sum()
+res = minimize(rosen_jax, np.array([0.,0.]), backend='jax')
+print(res.x)
+>>> array([0.99999912, 0.99999824])
+```
+
 You can also try other optimization methods such as Newton-CG which uses
 automatic computation of the hessian vector product (hvp). Let's as well
 increase the precision of hvp and gradient computation to float64 and the tolerance to 1e-8:
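(The Newton-CG example that follows this sentence in the README sits outside the hunks shown here. A minimal sketch of that kind of call, assuming `minimize` exposes `method`, `precision`, and `tol` keywords as the text above suggests; those names are not confirmed by this diff:)

```
import numpy as np
from autograd_minimize import minimize

def rosen(x):
    return (100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0).sum()

# Newton-CG relies on the automatically computed hessian vector product (hvp);
# the float64 precision and 1e-8 tolerance follow the README text above.
res = minimize(rosen, np.array([0., 0.]), backend='jax',
               method='Newton-CG', precision='float64', tol=1e-8)
print(res.x)
```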
@@ -70,7 +82,7 @@ U = random((shape[0], inner_shape))
 V = random((inner_shape, shape[1]))
 prod = U@V
-def mat_fac(U=None, V=None):
+def mat_fac(U, V):
     return tf.reduce_mean((U@V-tf.constant(prod, dtype=tf.float32))**2)
 x0 = {'U': -random((shape[0], inner_shape)), 'V': random((inner_shape, shape[1]))}
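(The call that runs this factorization is outside the hunk. A rough sketch of what it might look like, assuming `minimize` accepts the dict `x0` directly, that `backend='tf'` is the right choice for this TensorFlow objective, and that the result's `x` mirrors the dict structure; none of this is shown in the diff:)

```
# Hypothetical usage, not part of this diff.
res = minimize(mat_fac, x0, backend='tf')
print(res.x['U'].shape, res.x['V'].shape)
```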
@@ -0,0 +1,103 @@
import jax
from .base_wrapper import BaseWrapper
import jax.numpy as np
import numpy as onp  # host numpy, used for the float64 arrays returned to scipy


class JaxWrapper(BaseWrapper):
    def __init__(self, func, precision: str = "float32"):
        self.func = func

        if precision == "float32":
            self.precision = np.float32
        elif precision == "float64":
            self.precision = np.float64
        else:
            raise ValueError

    def get_value_and_grad(self, input_var):
        assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
        input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)

        value, grads = self._get_value_and_grad(input_var_)

        # scipy expects float64 numpy arrays, so cast back from jax arrays.
        return [
            onp.array(value).astype(onp.float64),
            onp.array(self._concat(grads)[0]).astype(onp.float64),
        ]

    def get_hvp(self, input_var, vector):
        assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
        input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)
        vector_ = self._unconcat(np.array(vector, dtype=self.precision), self.shapes)

        res = self._get_hvp_tf(input_var_, vector_)
        return onp.array(self._concat(res)[0]).astype(onp.float64)

    def get_hess(self, input_var):
        assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
        input_var_ = np.array(input_var, dtype=self.precision)
        hess = onp.array(self._get_hess(input_var_)).astype(onp.float64)

        return hess

    def _get_hess(self, input_var):
        return jax.hessian(self._eval_func)(self._unconcat(input_var, self.shapes))

    def _get_value_and_grad(self, input_var):
        val_grad = jax.value_and_grad(self._eval_func)
        return val_grad(input_var)

    def _get_hvp_tf(self, input_var, vector):
        # hessian vector product via forward-over-reverse mode (see helpers below)
        return hvp_fwd_rev(self._eval_func, input_var, vector)

    def get_ctr_jac(self, input_var):
        assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
        input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)

        jac = self._get_ctr_jac(input_var_)

        return onp.array(jac).reshape((-1, self.var_num)).astype(onp.float64)

    def _get_ctr_jac(self, input_var):
        return jax.jacfwd(self._eval_ctr_func)(input_var)

    def _reshape(self, t, sh):
        if isinstance(t, onp.ndarray) or isinstance(t, np.ndarray):
            return np.reshape(t, sh)
        else:
            raise NotImplementedError

    def _tconcat(self, t_list, dim=0):
        if isinstance(t_list[0], onp.ndarray) or isinstance(t_list[0], np.ndarray):
            return np.concatenate(t_list, dim)
        else:
            raise NotImplementedError

    def _gather(self, t, i, j):
        if isinstance(t, onp.ndarray) or isinstance(t, np.ndarray):
            return t[i:j]
        elif i + 1 == j:
            return t
        else:
            raise NotImplementedError


# from: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#hessian-vector-products-using-both-forward-and-reverse-mode


# reverse-mode
def hvp(f, x, v):
    return jax.grad(lambda x: np.vdot(jax.grad(f)(x), v))(x)


# forward-over-reverse
def hvp_fwd_rev(f, primals, tangents):
    return jax.jvp(jax.grad(f), [primals], [tangents])[1]


# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
    g = lambda primals: jax.jvp(f, [primals], [tangents])[1]
    return jax.grad(g)(primals)
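(These helpers mirror the JAX autodiff cookbook. As a quick sanity check, not part of the commit, all three variants should agree with an explicit hessian-vector product on a small test function, using the `hvp`, `hvp_fwd_rev`, and `hvp_revfwd` functions defined above:)

```
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**4) + x[0] * x[1]

x = jnp.array([1., 2., 3.])
v = jnp.array([0.1, 0.2, 0.3])

# Explicit hessian-vector product for comparison.
explicit = jax.hessian(f)(x) @ v

print(jnp.allclose(hvp(f, x, v), explicit))          # True
print(jnp.allclose(hvp_fwd_rev(f, x, v), explicit))  # True
print(jnp.allclose(hvp_revfwd(f, x, v), explicit))   # True
```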