Skip to content

Commit

Permalink
Merge pull request jax-ml#2654 from google/pfix
Browse files Browse the repository at this point in the history
fix jaxpr invar avals
  • Loading branch information
mattjj authored Apr 9, 2020
2 parents f37f235 + 5f1f29e commit 9191843
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
def getvar(t):
var = t_to_var.get(id(t))
if var is None:
var = newvar(partial_val_aval(*t.pval))
t_to_var[id(t)] = var
aval = t.pval[0] if t.pval[0] is not None else abstract_unit
var = t_to_var[id(t)] = newvar(aval)
return var
sorted_tracers = toposort(out_tracers)
invars = map(getvar, in_tracers)
Expand All @@ -458,8 +458,7 @@ def getvar(t):
def getconstvar(c):
var = const_to_var.get(id(c))
if var is None:
var = newvar(get_aval(c))
const_to_var[id(c)] = var
var = const_to_var[id(c)] = newvar(get_aval(c))
return var
processed_eqn_ids = set()
for t in sorted_tracers:
Expand Down

0 comments on commit 9191843

Please sign in to comment.