I am converting code I wrote in Ceres to JAX so that I can debug and tune it more easily. The code is very straightforward but takes quite a bit of time to compile. I am looking for advice on how to speed this up: perhaps I can reformulate my code, or use a feature I am not aware of, to save some time. Any help appreciated. I am running jax 0.1.72 and jaxlib 0.1.51, compiling for CPU.
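The key to faster compilation time is to avoid the Python `for` loop. When a function is traced for `jax.jit`, a Python loop is unrolled, so the compiled program contains one copy of the loop body per iteration. Two options you might consider: `lax.fori_loop` or `vmap`.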
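A minimal sketch of the loop-unrolling difference, assuming a hypothetical problem: 200 measurements (`x_data`, `y_data`) of a line with hypothetical parameters `a` and `b`, and a squared-error loss. Neither the data nor the loss comes from the original code. The first function uses a Python loop; the second expresses the same computation with `jax.lax.fori_loop`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical data: 200 measurements of the line y = 2x + 1.
x_data = jnp.linspace(0.0, 1.0, 200)
y_data = 2.0 * x_data + 1.0


def loss_python_loop(a, b):
    # Under jax.jit this loop is unrolled during tracing: the compiled
    # program holds one copy of the body per measurement, so compile
    # time grows with the number of measurements.
    total = 0.0
    for k in range(x_data.shape[0]):
        r = a * x_data[k] + b - y_data[k]
        total = total + r ** 2
    return total


def loss_fori_loop(a, b):
    # fori_loop traces body_fun once and emits a single loop, so compile
    # time is largely independent of the number of iterations.
    def body_fun(k, total):
        r = a * x_data[k] + b - y_data[k]
        return total + r ** 2

    return lax.fori_loop(0, x_data.shape[0], body_fun, jnp.zeros(()))


fast_loss = jax.jit(loss_fori_loop)
print(fast_loss(1.5, 0.5))
```

Both functions compute the same value. Only the size of the traced program differs.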
Aside from summing the error, each iteration is independent. In other words, you could run each iteration independently, store the errors in a vector, and sum that vector for the final error. From looking at the examples, it's not clear to me how one would apply vmap here. Specifically, it's not clear to me how I would map `k` into `body_fun()` so that vmap could be used.
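One way to map `k` with `vmap` is to write the per-iteration work as a function that takes `k` as an explicit argument, then map it over an array of indices while holding the parameters fixed with `in_axes=(None, 0)`. This is a minimal sketch under assumptions: `measurement_error`, the dict keys `"a"` and `"b"`, and the linear residual are hypothetical stand-ins for the real per-measurement error.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 200 measurements of the line y = 2x + 1.
x_data = jnp.linspace(0.0, 1.0, 200)
y_data = 2.0 * x_data + 1.0


def measurement_error(params, k):
    # Squared error for a single measurement k; params is a dict of scalars.
    r = params["a"] * x_data[k] + params["b"] - y_data[k]
    return r ** 2


# in_axes=(None, 0): params is shared by every call, k is mapped over.
batched_error = jax.vmap(measurement_error, in_axes=(None, 0))


@jax.jit
def total_error(params):
    ks = jnp.arange(x_data.shape[0])
    errors = batched_error(params, ks)  # one error per measurement
    return jnp.sum(errors)


params = {"a": 1.5, "b": 0.5}
print(total_error(params))
print(jax.grad(total_error)(params))
```

For a residual this simple you could vectorize directly (`params["a"] * x_data + ...`), but `vmap` saves you from rewriting a more complicated per-measurement function by hand.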
I think I figured it out. Sort of... I packaged up the inputs into a dictionary and keep a tally of my error.
However, I am calling `grad` on this function, since I am using JAX to compute the Jacobian for use in `scipy.optimize`. JAX is giving me this now.
Am I using `fori_loop` correctly? I will try `scan` now.
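A sketch of the dictionary-in-the-carry pattern with `lax.fori_loop`, differentiated with `jax.grad`. According to the `fori_loop` documentation in current JAX releases, reverse-mode differentiation is only supported when the trip count is static (for example, when the bounds are Python ints); otherwise the loop is lowered to a `while_loop`, which `grad` cannot differentiate. It is an assumption that this was the cause of the error above. The data, `NUM_MEASUREMENTS`, and the dict keys are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical data: 200 measurements of the line y = 2x + 1.
x_data = jnp.linspace(0.0, 1.0, 200)
y_data = 2.0 * x_data + 1.0
NUM_MEASUREMENTS = x_data.shape[0]  # a Python int, so the trip count is static


def total_error(params):
    def body_fun(k, carry):
        p, total = carry  # params dict (passed through) and running error
        r = p["a"] * x_data[k] + p["b"] - y_data[k]
        return p, total + r ** 2

    init = (params, jnp.zeros(()))
    _, total = lax.fori_loop(0, NUM_MEASUREMENTS, body_fun, init)
    return total


params = {"a": jnp.float32(1.5), "b": jnp.float32(0.5)}
grads = jax.grad(total_error)(params)  # a dict with keys "a" and "b"
print(grads)
```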
I got it working by using `lax.scan()`. I made `carry` a dictionary of my input parameters and stored the error in `y`. I just used `carry` as a pass-through for the params by having my function return `carry, error`. For the `xs` input to `scan`, I used `np.arange(num_measurements)`.
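The `scan` setup described above can be sketched as follows, assuming the same kind of hypothetical linear residual and `"a"`/`"b"` parameter dict (neither is from the original code). The params dict is the carry and is returned unchanged, each step emits its error into `ys`, and `xs` is the array of measurement indices.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical data: 200 measurements of the line y = 2x + 1.
x_data = jnp.linspace(0.0, 1.0, 200)
y_data = 2.0 * x_data + 1.0
num_measurements = x_data.shape[0]


def step(carry, k):
    # carry: the params dict, returned unchanged (a pass-through).
    # k: the current element of xs, i.e. the measurement index.
    r = carry["a"] * x_data[k] + carry["b"] - y_data[k]
    return carry, r ** 2  # scan stacks the second value into ys


@jax.jit
def total_error(params):
    _, errors = lax.scan(step, params, jnp.arange(num_measurements))
    return jnp.sum(errors)  # errors has shape (num_measurements,)


params = {"a": jnp.float32(1.5), "b": jnp.float32(0.5)}
loss, grads = jax.value_and_grad(total_error)(params)
print(loss, grads)
```

Since the params never change, an equivalent choice is to close over them in `step` and pass `None` as the carry. To hand the gradient dict to `scipy.optimize`, which expects a flat 1-D array, one option is `jax.flatten_util.ravel_pytree`.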