Skip to content

Commit

Permalink
Remove some untested dynamic shapes paths (prep work for stackless).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676529297
  • Loading branch information
dougalm authored and Google-ML-Automation committed Sep 19, 2024
1 parent 5f044a6 commit 63e7b7d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 101 deletions.
103 changes: 3 additions & 100 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,6 @@ def new_instantiated_literal(self, val) -> JaxprTracer:

def new_instantiated_const(self, val) -> JaxprTracer:
aval = get_aval(val)
if isinstance(aval, DShapedArray):
shape = [self.new_instantiated_const(d)
if isinstance(d, Tracer) and d._trace.level < self.level else d
for d in aval.shape]
aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val))

def new_arg(self, pval: PartialVal) -> JaxprTracer:
Expand Down Expand Up @@ -258,15 +253,9 @@ def process_call(self, primitive, f, tracers, params):
# which were unknown to the first call (corresponding to in_avals).

# Wrap f to perform the partial evaluation and plumb out aux data.
if not config.dynamic_shapes.value:
f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
tuple(in_avals))
else:
if f.in_type is None:
f = lu.annotate(f, tuple((a, True) for a in in_avals))
f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns),
f.in_type, False)
f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
tuple(in_avals))
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
const_params = update_params(params, in_knowns, 0)

Expand Down Expand Up @@ -569,92 +558,6 @@ def partial_eval_wrapper_nounits(
out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env)

@lu.transformation_with_aux
def trace_to_subjaxpr_nounits_dyn(
main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType,
instantiate: bool | Sequence[bool],
*in_consts: Any):
trace = main.with_cur_sublevel()
in_avals, which_explicit = unzip2(in_type)

# To form input tracers from in_type, we need to first build ConstVar tracers
# for all axis sizes, so that we can then use those tracers in the shapes of
# avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly
# type agreement checking via get_referent.
in_consts_full: list[JaxprTracer | None] = [None] * len(in_type)
in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns)
for idx, (aval, explicit) in enumerate(in_type):
if explicit and next(in_knowns_iter):
constval = next(in_consts_iter)
if isinstance(aval, DShapedArray):
for i, d in enumerate(aval.shape):
if isinstance(d, DBIdx):
if in_consts_full[d.val] is None:
in_consts_full[d.val] = \
JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]),
ConstVar(constval.shape[i]))
assert core.same_referent(constval.shape[i], in_consts_full[d.val])
shape = [in_consts_full[d.val] if type(d) is DBIdx else d
for d in aval.shape]
aval = aval.update(shape=tuple(shape))
in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval),
ConstVar(constval))
# Check that we covered all axis sizes with ConstVar tracers.
for idx, (aval, explicit) in enumerate(in_type):
if not explicit: assert in_consts_full[idx] is not None
if isinstance(aval, DShapedArray):
assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None
for d in aval.shape)

# Next, build tracers for all unknown inputs, using the in_consts_full list
# for axis size tracers when necessary.
in_tracers = []
in_knowns_iter = iter(in_knowns)
for aval, explicit in in_type:
if explicit and not next(in_knowns_iter):
if isinstance(aval, DShapedArray):
shape = [in_consts_full[d.val] if type(d) is DBIdx else d
for d in aval.shape]
aval = aval.update(shape=tuple(shape))
tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding())
in_tracers.append(tracer)

# Merge in_consts and in_tracers and call wrapped fn with explicit arguments.
in_args = merge_lists(in_knowns, in_tracers, in_consts)
ans = yield in_args, {}

# Instantiate outputs and build jaxpr.
if isinstance(instantiate, bool):
instantiate = [instantiate] * len(ans)
out_tracers = map(trace.full_raise, map(core.full_lower, ans))
out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t
for inst, t in zip(instantiate, out_tracers)]

# Collect known outputs.
out_knowns: list[bool] = [t.is_known() for t in out_tracers]
out_consts: list[Any] = [t.pval.get_known() for t in out_tracers
if t.is_known()]

# Build the jaxpr.
out_tracers = [t for t in out_tracers if not t.is_known()]
jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers)
out_avals = [v.aval for v in jaxpr.outvars]
idx_map = {v: InDBIdx(i)
for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore
if type(a) is DShapedArray else a, True) for a in out_avals]

# Which residuals are just forwarded inputs? Check obj id, then prune.
id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore
if c is not None}
fwds: list[int | None] = [id_map.get(id(c)) for c in res]
res = tuple(c for c, fwd in zip(res, fwds) if fwd is None)

del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \
in_tracers, in_args, ans, out_tracers, out_avals
yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env)


custom_partial_eval_rules: dict[Primitive, Callable] = {}
call_partial_eval_rules: dict[Primitive, Callable] = {}
call_param_updaters: dict[Primitive, Callable] = {}
Expand Down
1 change: 0 additions & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic,
trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2,
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
trace_to_subjaxpr_nounits_dyn as trace_to_subjaxpr_nounits_dyn,
trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd,
tracers_to_jaxpr as tracers_to_jaxpr,
trivial_ctx as trivial_ctx,
Expand Down

0 comments on commit 63e7b7d

Please sign in to comment.