Thanks for the report! Leaked tracers generally come from caching traced values within a JIT-compiled execution. The simplest version might look something like this:

```python
import jax


class Foo:
    def func(self, x):
        # Storing a traced value on `self` lets it escape the trace.
        self.x = x
        return x


f = Foo()
jax.jit(f.func)(1)  # after this call, f.x holds a leaked tracer
```

Your code doesn't do any of this kind of caching explicitly, but it happens implicitly via your use of … I think the fix would be to avoid using the … Best of luck!
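One common way to avoid this pattern is to keep the jitted function pure, so its only output is its return value, and to cache the *concrete* result on the object after the jitted call has returned. The sketch below assumes that is acceptable for your use case; the name `_pure_func` is hypothetical and not from your code.

```python
import jax


@jax.jit
def _pure_func(x):
    # Pure: nothing is written to outside state, so no tracer can escape.
    return x


class Foo:
    def __init__(self):
        self.x = None

    def func(self, x):
        # Tracing and compilation happen inside _pure_func; the value stored
        # on `self` is the concrete output, produced after tracing finished.
        out = _pure_func(x)
        self.x = out
        return out


f = Foo()
f.func(1)
print(f.x)  # a concrete array, not a tracer
```

Note that this only helps if `f.func` itself is not called from inside another `jax.jit` or `jax.vmap`; in that case `out` would again be a tracer, and the state should instead be threaded through as explicit arguments and return values.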
I am trying to implement DP-SGD in JAX. One of the steps uses `vmap` to clip the gradient computed for each example in a batch, but I am getting an error about a leaked tracer ("trace leak"). My error message is as follows:
Error message
Code is below:
I have traced the error to the line `clipped_grads, vs = vmap(clipped_grad, (None, 0))(l2_norm_clip, batch)`, but I don't know how to fix my code. Any help would be appreciated.
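For comparison, here is a minimal sketch of per-example clipping in DP-SGD written so that every piece of state (parameters, PRNG key) is passed in explicitly rather than read from or stored on an object, which leaves no place for a tracer to leak. The model, the `loss` function, and the names `per_example_clipped_grads` and `noisy_mean_grad` are hypothetical stand-ins, since the original code is not shown; the clipping and noise steps follow the standard DP-SGD recipe (clip each example's gradient to global L2 norm `l2_norm_clip`, sum, add Gaussian noise scaled by `noise_multiplier * l2_norm_clip`, divide by batch size).

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Hypothetical linear model with squared error, standing in for the real model.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2


def clipped_grad(params, l2_norm_clip, x, y):
    # Gradient for a single example.
    grads = jax.grad(loss)(params, x, y)
    leaves = jax.tree_util.tree_leaves(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale down so the global L2 norm is at most l2_norm_clip.
    scale = jnp.minimum(1.0, l2_norm_clip / (total_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


@jax.jit
def per_example_clipped_grads(params, l2_norm_clip, xs, ys):
    # params and l2_norm_clip are shared (None); xs and ys are mapped over axis 0.
    return jax.vmap(clipped_grad, in_axes=(None, None, 0, 0))(
        params, l2_norm_clip, xs, ys
    )


def noisy_mean_grad(params, key, l2_norm_clip, noise_multiplier, xs, ys):
    clipped = per_example_clipped_grads(params, l2_norm_clip, xs, ys)
    batch_size = xs.shape[0]
    leaves, treedef = jax.tree_util.tree_flatten(clipped)
    keys = jax.random.split(key, len(leaves))
    noisy = [
        # Sum over the batch axis, add Gaussian noise, then average.
        (jnp.sum(g, axis=0)
         + noise_multiplier * l2_norm_clip * jax.random.normal(k, g.shape[1:]))
        / batch_size
        for g, k in zip(leaves, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, noisy)


params = {"w": jnp.ones(3), "b": jnp.array(0.0)}
xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.zeros(4)
grads = noisy_mean_grad(params, jax.random.PRNGKey(0), 1.0, 1.1, xs, ys)
```

The design choice that matters here is that `clipped_grad` depends only on its arguments: if a function passed to `vmap` or `jit` writes to an attribute or global, or reads one that was set during an earlier trace, that is exactly where a leaked tracer tends to appear.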