I am trying to implement DP-SGD in JAX, and one of the steps involves using vmap to clip the gradient produced by each example in a batch. However, I am running into a "trace leak" error. The error message is as follows:
Error message
Traceback (most recent call last):
File "dp_train.py", line 404, in <module>
app.run(main)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "dp_train.py", line 377, in main
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
File "dp_train.py", line 111, in train
self.train_step(summary, next(train_iter), progress)
File "dp_train.py", line 83, in train_step
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 257, in __call__
output, changes = self._call(self.vc.tensors(), kwargs, *args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 247, in jit
return f(*args, **kwargs), self.vc.tensors()
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 183, in __call__
return self.wrapped(*args, **kwargs)
File "dp_train.py", line 207, in train_op
private_gradients, v= private_grad((x,y), self.params.seed, self.params.l2_norm_clip, self.params.noise_multiplier, self.params.batch)
File "dp_train.py", line 182, in private_grad
clipped_grads, vs = vmap(clipped_grad,(None,0))(l2_norm_clip,batch)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/api.py", line 1240, in vmap_f
out_flat = batching.batch(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 203, in call_wrapped
ans = gen.send(ans)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/interpreters/batching.py", line 588, in _batch_outer
del main
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/contextlib.py", line 120, in __exit__
next(self.gen)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/core.py", line 1106, in new_main
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(2,BatchTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1,128,1,1])>with<BatchTrace(level=2/0)> with
val = Traced<ShapedArray(float32[256,1,128,1,1])>with<DynamicJaxprTrace(level=1/0)>
batch_dim = 0
This BatchTracer with object id 140209544068528 was created on line:
dp_train.py:154 (loss)
<BatchTracer 140209544068528> is referred to by <tuple 140209550883136>[0]
[... the same leaked-tracer report repeats for dozens more parameter tensors (shapes float32[1,128,1,1], float32[1,64,1,1], float32[1,32,1,1], and float32[1,16,1,1]), every one created at dp_train.py:154 (loss) and referred to by <tuple 140209550883136>[0] ...]
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
My code is below:
class TrainLoop(objax.Module):
    """
    Training loop for general machine learning models.
    Based on the training loop from the objax CIFAR10 example code.
    """
    predict: Callable
    train_op: Callable

    def __init__(self, nclass: int, **kwargs):
        self.nclass = nclass
        self.params = EasyDict(kwargs)

    def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
        with checking_leaks():
            kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
        print('done')
        for k, v in kv.items():
            if jnp.isnan(v):
                raise ValueError('NaN, try reducing learning rate', k)
            if summary is not None:
                summary.scalar(k, float(v))

    def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str,
              save_steps=100, patience=None):
        """
        Completely standard training. Nothing interesting to see here.
        """
        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
        start_epoch, last_ckpt = checkpoint.restore(self.vars())
        train_iter = iter(train)
        progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU
        best_acc = 0
        best_acc_epoch = -1
        with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
            for epoch in range(start_epoch, num_train_epochs):
                # Train
                summary = Summary()
                loop = range(0, train_size, self.params.batch)
                for step in loop:
                    progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                    self.train_step(summary, next(train_iter), progress)
                # Eval
                accuracy, total = 0, 0
                if epoch % FLAGS.eval_steps == 0 and test is not None:
                    for data in test:
                        total += data['image'].shape[0]
                        preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                        accuracy += (preds == data['label'].numpy()).sum()
                    accuracy /= total
                    summary.scalar('eval/accuracy', 100 * accuracy)
                    tensorboard.write(summary, step=(epoch + 1) * train_size)
                    print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
                                                                  summary['eval/accuracy']()))
                    if summary['eval/accuracy']() > best_acc:
                        best_acc = summary['eval/accuracy']()
                        best_acc_epoch = epoch
                    elif patience is not None and epoch > best_acc_epoch + patience:
                        print("early stopping!")
                        checkpoint.save(self.vars(), epoch + 1)
                        return
                else:
                    print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
                if epoch % save_steps == save_steps - 1:
                    checkpoint.save(self.vars(), epoch + 1)


class MemModule(TrainLoop):
    def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
        """
        Completely standard training. Nothing interesting to see here.
        """
        super().__init__(nclass, **kwargs)
        self.model = model(1 if mnist else 3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}

        gv = objax.GradValues(loss, self.model.vars())  # gv(x, y): return the averaged gradient of batch x
        self.gv = gv

        @objax.Function.with_vars(self.vars())
        def clipped_grad(l2_norm_clip, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads, single_v = gv(jnp.array([single_example_batch[0]]), jnp.array([single_example_batch[1]]))
            nonempty_grads, tree_def = tree_flatten(grads)
            total_grad_norm = jnp.linalg.norm(
                jnp.array([jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads]))
            divisor = jnp.maximum(total_grad_norm / l2_norm_clip, 1.)
            normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
            return tree_unflatten(tree_def, normalized_nonempty_grads), single_v

        def private_grad(batch, seed, l2_norm_clip, noise_multiplier, batch_size):
            """Return differentially private gradients for params, evaluated on batch."""
            clipped_grads, vs = vmap(clipped_grad, (None, 0))(l2_norm_clip, batch)
            clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
            aggregated_clipped_grads = [g.sum(0) for g in clipped_grads_flat]
            rngs = random.split(random.key(seed), len(aggregated_clipped_grads))
            noised_aggregated_clipped_grads = [
                g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
                for r, g in zip(rngs, aggregated_clipped_grads)]
            normalized_noised_aggregated_clipped_grads = [
                g / batch_size for g in noised_aggregated_clipped_grads]
            v_loss = {}
            for key in vs[1].keys():
                v_loss[key] = jnp.mean(vs[1][key])
            return tree_unflatten(grads_treedef, normalized_noised_aggregated_clipped_grads), v_loss

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            private_gradients, v = private_grad((x, y), self.params.seed, self.params.l2_norm_clip,
                                                self.params.noise_multiplier, self.params.batch)
            lr = self.params.lr * jnp.cos(progress * (7 * jnp.pi) / (2 * 8))
            lr = lr * jnp.clip(progress * 100, 0, 1)
            self.opt(lr, private_gradients)
            self.model_ema.update_ema()
            return {'monitors/lr': lr, **v}

        self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
        self.train_op = objax.Jit(train_op)
I have located the error at the line "clipped_grads, vs = vmap(clipped_grad, (None, 0))(l2_norm_clip, batch)", but I don't know how to modify my code. Any help would be sincerely appreciated.
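For reference, here is a minimal, stateless sketch of the vmap per-example-clipping pattern I am trying to reproduce, in the style of the pure-JAX DP-SGD examples. The `loss_fn` and toy linear model here are hypothetical stand-ins for my actual objax model, not my real code; the point is that when the loss is a pure function of explicit `params`, vmap over the per-example gradient does not leak tracers:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model: squared error on a single example (a stand-in
    # for the real model; any pure loss of (params, x, y) would do).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def clipped_grad(params, l2_norm_clip, x, y):
    # Gradient for ONE example, clipped so its global L2 norm <= l2_norm_clip.
    grads = jax.grad(loss_fn)(params, x, y)
    leaves, treedef = jax.tree_util.tree_flatten(grads)
    total_norm = jnp.linalg.norm(
        jnp.array([jnp.linalg.norm(g.ravel()) for g in leaves]))
    divisor = jnp.maximum(total_norm / l2_norm_clip, 1.0)
    return jax.tree_util.tree_unflatten(treedef, [g / divisor for g in leaves])

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples
ys = jnp.ones((4,))

# vmap over the batch axis; params and the clip norm broadcast (in_axes=None).
per_example = jax.vmap(clipped_grad, in_axes=(None, None, 0, 0))(params, 1.0, xs, ys)
```

Each leaf of `per_example` gains a leading batch dimension (here `per_example["w"]` has shape `(4, 3)`), and every example's flattened gradient has norm at most the clip value, ready to be summed and noised.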
-
I am trying to write a DP-SGD using JAX, and one of the steps involves using vmap to clip the gradients generated by each example in a batch. However, I am encountering an error message that says "trace leak." My error message is as follows:
Error message
Traceback (most recent call last):
File "dp_train.py", line 404, in
app.run(main)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "dp_train.py", line 377, in main
tm.train(FLAGS.epochs, len(xs), train, test, logdir,
File "dp_train.py", line 111, in train
self.train_step(summary, next(train_iter), progress)
File "dp_train.py", line 83, in train_step
kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 257, in call
output, changes = self._call(self.vc.tensors(), kwargs, *args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 247, in jit
return f(*args, **kwargs), self.vc.tensors()
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/objax/module.py", line 183, in call
return self.wrapped(*args, **kwargs)
File "dp_train.py", line 207, in train_op
private_gradients, v= private_grad((x,y), self.params.seed, self.params.l2_norm_clip, self.params.noise_multiplier, self.params.batch)
File "dp_train.py", line 182, in private_grad
clipped_grads, vs = vmap(clipped_grad,(None,0))(l2_norm_clip,batch)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/api.py", line 1240, in vmap_f
out_flat = batching.batch(
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/linear_util.py", line 203, in call_wrapped
ans = gen.send(ans)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/interpreters/batching.py", line 588, in _batch_outer
del main
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/contextlib.py", line 120, in exit
next(self.gen)
File "/home/ubuntu/anaconda3/envs/lira/lib/python3.8/site-packages/jax/_src/core.py", line 1106, in new_main
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(2,BatchTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1,128,1,1])>with<BatchTrace(level=2/0)> with
val = Traced<ShapedArray(float32[256,1,128,1,1])>with<DynamicJaxprTrace(level=1/0)>
batch_dim = 0
This BatchTracer with object id 140209544068528 was created on line:
dp_train.py:154 (loss)
<BatchTracer 140209544068528> is referred to by <tuple 140209550883136>[0]
[... the same "Leaked tracer" entry repeats for each remaining model parameter, with shapes float32[1,128,1,1], float32[1,64,1,1], float32[1,32,1,1], and float32[1,16,1,1]; every BatchTracer was created at dp_train.py:154 (loss) and is referred to by <tuple 140209550883136>[0] ...]
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
Code is below:
class TrainLoop(objax.Module):
    """
    Training loop for general machine learning models.
    Based on the training loop from the objax CIFAR10 example code.
    """
    predict: Callable
    train_op: Callable

class MemModule(TrainLoop):
    def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
        """
        Completely standard training. Nothing interesting to see here.
        """
        super().__init__(nclass, **kwargs)
        self.model = model(1 if mnist else 3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
I have located the error: it is raised when execution reaches the line "clipped_grads, vs = vmap(clipped_grad, (None, 0))(l2_norm_clip, batch)". However, I don't know how to modify my code to fix the leak. I would sincerely appreciate any help.
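For reference, the pattern I am trying to follow is the one from the official JAX differentially-private SGD example, where per-example clipping is written purely functionally: the parameters and clip threshold are passed as explicit arguments (broadcast with in_axes=None) and nothing traced is stored on a module. A minimal sketch of that pattern, with all names and the toy loss purely illustrative (this is not my actual model), is:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def loss(params, example):
    # Illustrative per-example loss: linear model with squared error.
    x, y = example
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2

def clipped_grad(params, l2_norm_clip, example):
    """Gradient for a single example, clipped to global L2 norm l2_norm_clip."""
    grads = jax.grad(loss)(params, example)
    flat, unravel = ravel_pytree(grads)          # flatten the gradient pytree
    norm = jnp.linalg.norm(flat)
    factor = jnp.minimum(1.0, l2_norm_clip / (norm + 1e-12))
    return unravel(flat * factor)                # rescale and restore structure

def per_example_clipped_grads(params, l2_norm_clip, batch):
    # Map only over the batch axis of the data; params and the clip
    # threshold are broadcast unchanged to every example.
    return jax.vmap(clipped_grad, in_axes=(None, None, 0))(params, l2_norm_clip, batch)

params = {"w": jnp.ones(3), "b": jnp.float32(0.0)}
batch = (jnp.arange(12.0).reshape(4, 3), jnp.zeros(4))
grads = per_example_clipped_grads(params, 0.5, batch)
```

Here clipped_grad closes over nothing stateful, so no tracer can outlive the vmap trace; every per-example gradient afterwards has L2 norm at most the clip threshold.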