Pure Callbacks do not support JVP. Is there a workaround? #23082
Hi, the error message is the one quoted in the title ("Pure callbacks do not support JVP"). Is there a workaround that avoids implementing a custom_jvp function? I also tried a custom_jvp function as shown in the documentation, but that function apparently is not being called, and the same error message is emitted. The steps I followed are:
Am I missing anything else?
The workaround is to define the custom_jvp. Can you say more about what error you're seeing when you try this? Can you share a reproduction of the issue?
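For reference, a minimal reproduction of this kind of error might look like the sketch below. It assumes a recent JAX release, and host_sin and f are hypothetical names rather than code from the question. The forward call through jax.pure_callback works, but jax.grad fails because pure_callback has no JVP rule of its own.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Hypothetical host-side function: plain NumPy, opaque to JAX tracing.
    return np.asarray(np.sin(x), dtype=x.dtype)


def f(x):
    # Route the computation through the host; output shape/dtype must be declared.
    return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


x = jnp.float32(1.0)
print(f(x))  # the forward pass works

message = None
try:
    jax.grad(f)(x)  # autodiff needs a JVP rule for pure_callback, which it lacks
except ValueError as err:
    message = str(err)
print(message)
```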
There are two problems here, I think:

(1) You define cwh as a custom_jvp, but you never define its JVP via cwh.defjvp. Fortunately, you define its JVP elsewhere, so you can delete cwh and just use compute_weights_host directly.

(2) You use compute_weights_wrapper in a setting where it is automatically differentiated, but it calls jax.pure_callback, which is not differentiable. You'll need to define a custom_jvp rule for compute_weights_wrapper to avoid trying to automatically differentiate the pure_callback.

You might find this doc helpful: it contains a worked example of a pure callback using custom_jvp to define autodiff rules: https://jax.readthedocs.io/en/latest/notebooks/external_callba…
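Applying both suggestions, the structure might look like the following sketch. The body of compute_weights_host is not shown in the thread, so the softmax used here is a hypothetical stand-in, as is the loss function. What matters is the pattern: a plain host function with no custom_jvp of its own, a custom_jvp wrapper that calls jax.pure_callback, and a JVP rule registered with defjvp. The tangent is computed in jax.numpy from the primal output, so it is linear in x_dot and jax.grad can transpose it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def compute_weights_host(x):
    # Hypothetical host-side body: a numerically stable softmax in NumPy.
    # (The real compute_weights_host from the question is not shown.)
    z = np.exp(x - np.max(x))
    return np.asarray(z / np.sum(z), dtype=x.dtype)


@jax.custom_jvp
def compute_weights_wrapper(x):
    # JAX never differentiates through this call; it uses the rule below instead.
    return jax.pure_callback(
        compute_weights_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )


@compute_weights_wrapper.defjvp
def compute_weights_wrapper_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    s = compute_weights_wrapper(x)
    # Softmax JVP, linear in x_dot: ds = s * (x_dot - sum(s * x_dot)).
    s_dot = s * (x_dot - jnp.sum(s * x_dot))
    return s, s_dot


c = jnp.array([1.0, 2.0, 3.0])


def loss(x):
    # Hypothetical downstream use that gets differentiated automatically.
    return jnp.sum(compute_weights_wrapper(x) * c)


x = jnp.array([0.5, -1.0, 2.0])
print(jax.grad(loss)(x))
```

If the derivative is only available on the host, the JVP rule can call the host code again through pure_callback (as the linked doc does) and then multiply the result by x_dot; the tangent output still needs to be linear in x_dot for reverse mode.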