When I use the custom_gradient decorator with the JAX backend, following the example located here:
custom_gradient_example, I get the error TypeError: 'NoneType' object is not callable.
The same code works without problems on the TensorFlow backend, but not on JAX. The behaviour occurs on both Keras v3.9.0 and keras-nightly.