Struggling with objax, Modules, StateVar, and python classes #262
Replies: 3 comments
Weirdly, I can replace the |
Even weirder, I've tracked it down to line 67 in
With |
Definitely some caching/closure issue. If I create a module-level function like so:

    def get_train_op(model):
        opt_model = objax.optimizer.Adam(model.vars())
        energy = objax.GradValues(model.energy, model.vars())

        def train_op(_s, _t):
            dE, E = energy(_s, _t)
            opt_model(0.1, dE)

        return objax.Jit(train_op, model.vars() + opt_model.vars())

and then use it to create my self.ops[i] = get_train_op(self.objs[i]), then everything works as expected, I can use the I literally traced the moment it changes the wrong |
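The code that originally built self.ops is not shown in this thread, so the following is only a guess at one mechanism consistent with the "closure issue" above: Python closures look up loop variables when they are called, not when they are defined. This is a plain-Python sketch with no objax involved, and every name in it (make_ops_in_loop, make_op, make_ops_with_factory, the string stand-ins for models) is hypothetical. It shows why moving the construction into a factory function, as get_train_op does, gives each op its own binding.

```python
# Hypothetical stand-ins: plain Python closures only, no objax.

def make_ops_in_loop(objs):
    """Build one callable per object directly inside the loop."""
    ops = []
    for obj in objs:
        # `obj` is looked up when the lambda runs, not when it is defined,
        # so every lambda sees whatever `obj` holds after the loop ends.
        ops.append(lambda: obj)
    return ops


def make_op(obj):
    """Factory: each call creates a fresh scope that pins its own `obj`."""
    def op():
        return obj
    return op


def make_ops_with_factory(objs):
    """Build one callable per object through the factory."""
    return [make_op(obj) for obj in objs]


objs = ["model_0", "model_1"]
loop_results = [op() for op in make_ops_in_loop(objs)]
factory_results = [op() for op in make_ops_with_factory(objs)]

print(loop_results)     # ['model_1', 'model_1']
print(factory_results)  # ['model_0', 'model_1']
```

With a single object the two versions behave identically, which would match the num = 1 case working and num = 2 failing, but whether this is what actually happens in the original code is an assumption.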
I am having trouble understanding how to use objax Modules with Python classes, and how to manually update StateVars. In the code below, I basically want a container holding num Test objects, each of which has its own train_op and Adam optimizer. It's a silly example and I'm not sure the maths makes sense, but it seems to highlight the issue.

If I run the code with num = 1 below, it Jit compiles and runs fine. As soon as I set num = 2, I get an UnexpectedTracerError which points to the opt_model(0.1, dE) line. The documentation seems to say that StateVars are used for manually tracking variables; in my case I have a set of manually updated variables and Adam-updated variables that are updated using the same loss function. What am I doing wrong here? (The functools.partial seems to break things even with num = 1.)