Skip to content

Commit

Permalink
[attrs] simplify input side of jvp internals
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 22, 2024
1 parent 3cf7ca8 commit 67572d3
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions jax/experimental/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,36 +81,40 @@ def _setattr_staging(trace, tracer, *, obj, attr):
pe.DynamicJaxprTrace.process_setattr = _setattr_staging


def jvp(f, primals, tangents, tangent_attrs_in):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree_ = tree_flatten(tangents)
def jvp(f, primals, tangents, attr_tangents):
attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents)
attr_primals = tuple(jax_getattr(o, a) for o, a in attrs)
primals_flat, in_tree = tree_flatten((attr_primals, primals))
tangents_flat, in_tree_ = tree_flatten((attr_tangents, tangents))
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
primals_flat, tangents_flat, tangent_attrs_in)
primals_flat, tangents_flat)
out_primals = tree_unflatten(out_tree(), out_primals_flat)
out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
return out_primals, out_tangents, tangent_attrs_out

@lu.transformation
def _set_attrs(attrs, attr_vals, args):
for (o, a), x in zip(attrs, attr_vals):
jax_setattr(o, a, x)
yield (yield args, {})

def _jvp(fun: lu.WrappedFun):
return jvpfun2(jvp_subtrace2(fun))

@lu.transformation
def jvpfun2(primals, tangents, tangent_attrs_in):
def jvpfun2(primals, tangents):
with core.new_main(ad.JVPTrace) as main:
out_primals, out_tangents, tangent_attrs_out = \
yield (main, primals, tangents, tangent_attrs_in), {}
yield (main, primals, tangents), {}
del main
yield out_primals, out_tangents, tangent_attrs_out

@lu.transformation
def jvp_subtrace2(main, primals, tangents, tangent_attrs_in):
def jvp_subtrace2(main, primals, tangents):
main.attrs_tracked = [] # attrs written to
trace = main.with_cur_sublevel()
for obj, name, tangent in tangent_attrs_in:
primal = jax_getattr(obj, name)
tracer = ad.JVPTracer(trace, primal, tangent)
jax_setattr(obj, name, tracer)
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
for x, t in zip(primals, tangents)]
ans = yield in_tracers, {}
Expand Down

0 comments on commit 67572d3

Please sign in to comment.