Add jax #5

Merged 7 commits on Sep 14, 2023

16 changes: 14 additions & 2 deletions README.md
@@ -1,7 +1,7 @@
# autograd-minimize

autograd-minimize is a wrapper around scipy's minimize routine that uses the autograd capabilities of
tensorflow or pytorch to automatically compute gradients,
jax, tensorflow or pytorch to automatically compute gradients,
hessian vector products and hessians.

It also accepts functions of more than one variable as input.
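For illustration (this sketch is not part of the README diff; the function and variable names are hypothetical), a function of two variables can be passed together with a dict of initial values, similar to the matrix-factorization example further down:

```
# Sketch only: a two-variable function optimized via a dict of initial values.
# Hypothetical names; assumes the dict keys match the function's parameter names.
import numpy as np
from autograd_minimize import minimize

def two_var_loss(a, b):
    # simple separable quadratic, minimized at a=1 and b=-2
    return ((a - 1.0) ** 2).sum() + ((b + 2.0) ** 2).sum()

x0 = {"a": np.zeros(3), "b": np.zeros(2)}
res = minimize(two_var_loss, x0, backend="torch")
print(res.fun)  # close to 0
```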
@@ -31,6 +31,7 @@ But you can also use pytorch:
```
import torch
from autograd_minimize import minimize
import numpy as np

def rosen_torch(x):
return (100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0).sum()
@@ -40,6 +41,17 @@ print(res.x)
>>> array([0.99999912, 0.99999824])
```

Or jax:
```
import numpy as np
from autograd_minimize import minimize

rosen_jax=lambda x: (100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0).sum()
res = minimize(rosen_jax, np.array([0.,0.]), backend='jax')
print(res.x)
>>> array([0.99999912, 0.99999824])
```

You can also try other optimization methods such as Newton-CG, which uses
automatic computation of the hessian vector product (hvp). Let's also
increase the precision of the hvp and gradient computations to float64 and tighten the tolerance to 1e-8:
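The concrete call is collapsed in this diff view; a minimal sketch of such a call, reusing the rosen_torch function defined above (the collapsed README example may differ in its exact arguments):

```
# Sketch only: Newton-CG uses the automatically computed hvp;
# precision='float64' and tol=1e-8 tighten the computation.
res = minimize(rosen_torch, np.array([0., 0.]), backend='torch',
               method='Newton-CG', precision='float64', tol=1e-8)
print(res.x)
```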
@@ -70,7 +82,7 @@ U = random((shape[0], inner_shape))
V = random((inner_shape, shape[1]))
prod = U@V

def mat_fac(U=None, V=None):
def mat_fac(U, V):
return tf.reduce_mean((U@V-tf.constant(prod, dtype=tf.float32))**2)

x0 = {'U': -random((shape[0], inner_shape)), 'V': random((inner_shape, shape[1]))}
103 changes: 103 additions & 0 deletions autograd_minimize/jax_wrapper.py
@@ -0,0 +1,103 @@
import jax
import jax.numpy as np
import numpy as onp

from .base_wrapper import BaseWrapper


class JaxWrapper(BaseWrapper):
def __init__(self, func, precision: str = "float32"):
self.func = func

if precision == "float32":
self.precision = np.float32
elif precision == "float64":
self.precision = np.float64
else:
raise ValueError

def get_value_and_grad(self, input_var):
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)

value, grads = self._get_value_and_grad(input_var_)

return [
onp.array(value).astype(onp.float64),
onp.array(self._concat(grads)[0]).astype(onp.float64),
]

def get_hvp(self, input_var, vector):
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)
vector_ = self._unconcat(np.array(vector, dtype=self.precision), self.shapes)

res = self._get_hvp_tf(input_var_, vector_)
return onp.array(self._concat(res)[0]).astype(onp.float64)

def get_hess(self, input_var):
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
input_var_ = np.array(input_var, dtype=self.precision)
hess = onp.array(self._get_hess(input_var_)).astype(onp.float64)

return hess

def _get_hess(self, input_var):
return jax.hessian(self._eval_func)(self._unconcat(input_var, self.shapes))

def _get_value_and_grad(self, input_var):
val_grad = jax.value_and_grad(self._eval_func)
return val_grad(input_var)

def _get_hvp_tf(self, input_var, vector):
return hvp_fwd_rev(self._eval_func, input_var, vector)

def get_ctr_jac(self, input_var):
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)

jac = self._get_ctr_jac(input_var_)

return onp.array(jac).reshape((-1, self.var_num)).astype(onp.float64)

def _get_ctr_jac(self, input_var):
return jax.jacfwd(self._eval_ctr_func)(input_var)

def _reshape(self, t, sh):
if isinstance(t, onp.ndarray) or isinstance(t, np.ndarray):
return np.reshape(t, sh)
else:
raise NotImplementedError

def _tconcat(self, t_list, dim=0):
if isinstance(t_list[0], onp.ndarray) or isinstance(t_list[0], np.ndarray):
return np.concatenate(t_list, dim)
else:
raise NotImplementedError

def _gather(self, t, i, j):
if isinstance(t, onp.ndarray) or isinstance(t, np.ndarray):
return t[i:j]
elif i + 1 == j:
return t
else:
raise NotImplementedError


# from: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#hessian-vector-products-using-both-forward-and-reverse-mode


# reverse-mode
def hvp(f, x, v):
return jax.grad(lambda x: np.vdot(jax.grad(f)(x), v))(x)


# forward-over-reverse
def hvp_fwd_rev(f, primals, tangents):
return jax.jvp(jax.grad(f), [primals], [tangents])[1]


# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jax.jvp(f, [primals], [tangents])[1]
return jax.grad(g)(primals)
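As an aside (not part of the PR), the forward-over-reverse formulation used by `_get_hvp_tf` can be sanity-checked against an explicit hessian-vector product on a small example:

```
# Sketch only: verify that forward-over-reverse hvp matches the dense hessian product.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2) + jnp.prod(x)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])

hvp_value = jax.jvp(jax.grad(f), [x], [v])[1]  # same pattern as hvp_fwd_rev
explicit = jax.hessian(f)(x) @ v               # dense hessian times vector
print(jnp.allclose(hvp_value, explicit))       # True
```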
12 changes: 6 additions & 6 deletions autograd_minimize/scipy_minimize.py
@@ -85,9 +85,11 @@ def minimize(
elif backend == "torch":
from .torch_wrapper import TorchWrapper

wrapper = TorchWrapper(
fun, precision=precision, hvp_type=hvp_type, device=torch_device
)
wrapper = TorchWrapper(fun, precision=precision, hvp_type=hvp_type, device=torch_device)
elif backend == "jax":
from .jax_wrapper import JaxWrapper

wrapper = JaxWrapper(fun, precision=precision)
else:
raise NotImplementedError

@@ -113,9 +115,7 @@ def minimize(
wrapper.get_input(x0),
method=method,
jac=True,
hessp=wrapper.get_hvp
if method in ["Newton-CG", "trust-ncg", "trust-krylov", "trust-constr"]
else None,
hessp=wrapper.get_hvp if method in ["Newton-CG", "trust-ncg", "trust-krylov", "trust-constr"] else None,
hess=wrapper.get_hess if method in ["dogleg", "trust-exact"] else None,
bounds=wrapper.get_bounds(bounds),
constraints=wrapper.get_constraints(constraints, method),
39 changes: 10 additions & 29 deletions autograd_minimize/torch_wrapper.py
@@ -34,21 +34,15 @@ def __init__(
self.hvp_func = hvp if hvp_type == "hvp" else vhp

def get_value_and_grad(self, input_var):
assert "shapes" in dir(
self
), "You must first call get input to define the tensors shapes."
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."

input_var_ = self._unconcat(
torch.tensor(
input_var, dtype=self.precision, requires_grad=True, device=self.device
),
torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
self.shapes,
)

loss = self._eval_func(input_var_)
input_var_grad = (
input_var_.values() if isinstance(input_var_, dict) else input_var_
)
input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
grads = torch.autograd.grad(loss, input_var_grad)

if isinstance(input_var_, dict):
@@ -60,17 +54,13 @@ def get_value_and_grad(self, input_var):
]

def get_hvp(self, input_var, vector):
assert "shapes" in dir(
self
), "You must first call get input to define the tensors shapes."
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."

input_var_ = self._unconcat(
torch.tensor(input_var, dtype=self.precision, device=self.device),
self.shapes,
)
vector_ = self._unconcat(
torch.tensor(vector, dtype=self.precision, device=self.device), self.shapes
)
vector_ = self._unconcat(torch.tensor(vector, dtype=self.precision, device=self.device), self.shapes)

if isinstance(input_var_, dict):
input_var_ = tuple(input_var_.values())
@@ -87,9 +77,7 @@ def get_hvp(self, input_var, vector):
return self._concat(vhp_res)[0].cpu().detach().numpy().astype(np.float64)

def get_hess(self, input_var):
assert "shapes" in dir(
self
), "You must first call get input to define the tensors shapes."
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
input_var_ = torch.tensor(input_var, dtype=self.precision, device=self.device)

def func(inp):
@@ -100,23 +88,16 @@ def func(inp):
return hess.cpu().detach().numpy().astype(np.float64)

def get_ctr_jac(self, input_var):
assert "shapes" in dir(
self
), "You must first call get input to define the tensors shapes."
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."

input_var_ = self._unconcat(
torch.tensor(
input_var, dtype=self.precision, requires_grad=True, device=self.device
),
torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
self.shapes,
)

ctr_val = self._eval_ctr_func(input_var_)
input_var_grad = (
input_var_.values() if isinstance(input_var_, dict) else input_var_
)
grads = torch.autograd.grad(ctr_val, input_var_grad)

input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
grads = torch.autograd.functional.jacobian(self._eval_ctr_func, input_var_grad)
return grads.cpu().detach().numpy().astype(np.float64)

def _reshape(self, t, sh):
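For context (not part of the PR), the new get_ctr_jac relies on torch.autograd.functional.jacobian, which returns the full Jacobian of a tensor-valued function; a minimal sketch of that call:

```
# Sketch only: Jacobian of a vector-valued constraint function.
import torch

def ctr(x):
    # two constraints of a 2-vector x
    return torch.stack([x[0] + 2.0 * x[1] - 1.0, x[0] - x[1]])

x = torch.tensor([0.5, 1.0])
jac = torch.autograd.functional.jacobian(ctr, x)
print(jac)  # shape (2, 2)
```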
6 changes: 4 additions & 2 deletions requirements.txt
@@ -4,5 +4,7 @@ numpy==1.23.5
pandas==1.5.3
scipy==1.10.0
setuptools==65.5.1
tensorflow==2.11.0
torch==1.13.1
tensorflow==2.13.0
torch==2.0.1
jax==0.4.12
jaxlib==0.4.12
58 changes: 23 additions & 35 deletions test/test_lib.py
@@ -1,4 +1,7 @@
from time import time
import jax

# Enable 64-bit floats in jax so the precision="float64" tests behave as expected.
jax.config.update("jax_enable_x64", True)

import numpy as np
import tensorflow as tf
@@ -12,22 +15,25 @@
from autograd_minimize import minimize
from autograd_minimize.tf_wrapper import tf_function_factory
from autograd_minimize.torch_wrapper import torch_function_factory
import pytest


def rosen_tst(backend="torch"):
@pytest.mark.parametrize("backend", ["tf", "torch", "jax"])
def test_rosen(backend):
"""
Adapted from: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
"""

def rosen_tf(x):
return tf.reduce_sum(
100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0
)
return tf.reduce_sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

def rosen_torch(x):
return (100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0).sum()

func = rosen_tf if backend == "tf" else rosen_torch
if backend == "tf":
func = rosen_tf
else:
func = rosen_torch
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])

for method in [
@@ -46,42 +52,34 @@ def rosen_torch(x):
"trust-exact", # requires hessian
"trust-krylov",
]:

tic = time()
res = minimize(
func, x0, backend=backend, precision="float64", method=method, tol=1e-8
)
res = minimize(func, x0, backend=backend, precision="float64", method=method, tol=1e-8)

print(method, time() - tic, np.mean(res.x - 1))
assert_almost_equal(res.x, 1, decimal=5)


def test_rosen_tf():
rosen_tst("tf")


def test_rosen_torch():
rosen_tst("torch")


def test_cstr_opt():
@pytest.mark.parametrize("backend", ["tf", "torch", "jax"])
def test_cstr_opt(backend):
"""
Adapted from: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
"""

def fun(x):
return (x[0] - 1) ** 2 + (x[1] - 2.5) ** 2

if backend in ["tf", "jax"]:
fun_ctr = lambda x: np.array([1, -1, -1]) * x[0] + np.array([-2, -2, +2]) * x[1] + np.array([2, 6, 2])
else:
fun_ctr = lambda x: torch.tensor([1, -1, -1]) * x[0] + torch.tensor([-2, -2, +2]) * x[1] + torch.tensor([2, 6, 2])
cons = {
"type": "ineq",
"fun": lambda x: np.array([1, -1, -1]) * x[0]
+ np.array([-2, -2, +2]) * x[1]
+ np.array([2, 6, 2]),
"fun": fun_ctr,
}

bnds = ((0, None), (0, None))

res = minimize(fun, np.array([2, 0]), method="SLSQP", bounds=bnds, constraints=cons)
res = minimize(fun, np.array([2, 0]), method="SLSQP", bounds=bnds, backend=backend, constraints=cons)

assert_almost_equal(res.x, np.array([1.4, 1.7]), decimal=6)

@@ -95,13 +93,7 @@ def model(U=None, V=None):
return tf.reduce_mean((U @ V - tf.constant(prod, dtype=tf.float32)) ** 2)

def model_torch(smv=None, smp=None):
return (
(
smv[:, None, :, None] * smp[None, :, None, :]
- torch.tensor(prod, dtype=torch.float32)
)
** 2
).mean()
return ((smv[:, None, :, None] * smp[None, :, None, :] - torch.tensor(prod, dtype=torch.float32)) ** 2).mean()

x0 = {"U": -random((shape[0], inner_shape)), "V": random((inner_shape, shape[1]))}

@@ -139,9 +131,7 @@ def n_knapsack(

# We create knapsacks with attribution of the items to knapsacks [0,1,2,3,4] as:
# [0 1 2 3 4 0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]
capacity_knapsacks = weights_.reshape((n_weights_per_items, -1, n_knapsacks)).sum(
-2
)
capacity_knapsacks = weights_.reshape((n_weights_per_items, -1, n_knapsacks)).sum(-2)

if backend == "tf":
weights_ = tf.constant(weights_, tf.float32)
@@ -161,9 +151,7 @@ def func(W):
else:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights_ = torch.tensor(weights_, dtype=torch.float32, device=dev)
capacity_knapsacks_ = torch.tensor(
capacity_knapsacks, dtype=torch.float32, device=dev
)
capacity_knapsacks_ = torch.tensor(capacity_knapsacks, dtype=torch.float32, device=dev)

def func(W):
# We use softmax to impose the constraint that the attribution of items to knapsacks should sum to one
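The body of the knapsack objective is collapsed above; as an illustration (not the PR's exact code), the softmax trick described in that comment looks like this:

```
# Sketch only: softmax over the knapsack axis so each item's attribution sums to one.
import torch

W = torch.randn(20, 5)                  # raw scores: 20 items x 5 knapsacks
attribution = torch.softmax(W, dim=-1)  # each row now sums to 1
print(attribution.sum(-1))              # ~ tensor of ones
```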