Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

debug: jax order of parameters can change from input to output #7

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions autograd_minimize/base_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,13 @@ def get_input(self, input_var):
return input_

def get_output(self, output_var):
assert "shapes" in dir(
self
), "You must first call get input to define the tensors shapes."
assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
output_var_ = self._unconcat(output_var, self.shapes)
return output_var_

def get_bounds(self, bounds):

if bounds is not None:
if isinstance(bounds, tuple) and not (
isinstance(bounds[0], tuple) or isinstance(bounds[0], sopt.Bounds)
):
if isinstance(bounds, tuple) and not (isinstance(bounds[0], tuple) or isinstance(bounds[0], sopt.Bounds)):
assert len(bounds) == 2
new_bounds = [bounds] * self.var_num

Expand All @@ -53,23 +48,18 @@ def get_bounds(self, bounds):
if k in bounds.keys():
new_bounds += format_bounds(bounds[k], self.shapes[k])
else:
new_bounds += [(None, None)] * np.prod(
self.shapes[k], dtype=np.int32
)
new_bounds += [(None, None)] * np.prod(self.shapes[k], dtype=np.int32)
else:
new_bounds = bounds
return new_bounds

def get_constraints(self, constraints, method):
if constraints is not None and not isinstance(
constraints, sopt.LinearConstraint
):
if constraints is not None and not isinstance(constraints, sopt.LinearConstraint):
assert isinstance(constraints, dict)
assert "fun" in constraints.keys()
self.ctr_func = constraints["fun"]
use_autograd = constraints.get("use_autograd", True)
if method in ["trust-constr"]:

constraints = sopt.NonlinearConstraint(
self._eval_ctr_func,
lb=constraints.get("lb", -np.inf),
Expand Down Expand Up @@ -128,6 +118,9 @@ def get_ctr_jac(self, input_var):
def _concat(self, ten_vals):
ten = []
if isinstance(ten_vals, dict):
if "shapes" in dir(self):
if not all(map(lambda k: k[1] == k[0], zip(ten_vals.keys(), self.shapes))):
                raise AssertionError("Results should always be in the same order")
shapes = {}
for k, t in ten_vals.items():
if t is not None:
Expand Down Expand Up @@ -163,9 +156,7 @@ def _unconcat(self, ten, shapes):
ten_vals = {}
for k, sh in shapes.items():
next_ind = current_ind + np.prod(sh, dtype=np.int32)
ten_vals[k] = self._reshape(
self._gather(ten, current_ind, next_ind), sh
)
ten_vals[k] = self._reshape(self._gather(ten, current_ind, next_ind), sh)
current_ind = next_ind

elif isinstance(shapes, list) or isinstance(shapes, tuple):
Expand All @@ -175,9 +166,7 @@ def _unconcat(self, ten, shapes):
ten_vals = []
for sh in shapes:
next_ind = current_ind + np.prod(sh, dtype=np.int32)
ten_vals.append(
self._reshape(self._gather(ten, current_ind, next_ind), sh)
)
ten_vals.append(self._reshape(self._gather(ten, current_ind, next_ind), sh))

current_ind = next_ind

Expand All @@ -186,17 +175,14 @@ def _unconcat(self, ten, shapes):

return ten_vals

@abstractmethod
def _reshape(self, t, sh):
return
return np.reshape(t, sh)

@abstractmethod
def _tconcat(self, t_list, dim=0):
return
return np.concatenate(t_list, dim)

@abstractmethod
def _gather(self, t, i, j):
return
return t[i:j]


def format_bounds(bounds_, sh):
Expand Down
9 changes: 6 additions & 3 deletions autograd_minimize/jax_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_hvp(self, input_var, vector):
input_var_ = self._unconcat(np.array(input_var, dtype=self.precision), self.shapes)
vector_ = self._unconcat(np.array(vector, dtype=self.precision), self.shapes)

res = self._get_hvp_tf(input_var_, vector_)
res = self._get_hvp(input_var_, vector_)
return onp.array(self._concat(res)[0]).astype(onp.float64)

def get_hess(self, input_var):
Expand All @@ -47,9 +47,12 @@ def _get_hess(self, input_var):

def _get_value_and_grad(self, input_var):
val_grad = jax.value_and_grad(self._eval_func)
return val_grad(input_var)
val, grads = val_grad(input_var)
if isinstance(self.shapes, dict):
grads = {k: grads[k] for k in self.shapes.keys()}
return val, grads

def _get_hvp_tf(self, input_var, vector):
def _get_hvp(self, input_var, vector):
return hvp_fwd_rev(self._eval_func, input_var, vector)

def get_ctr_jac(self, input_var):
Expand Down