
custom_gradient not working with JAX backend #21105

@larschristensen

When I use the custom_gradient decorator with the JAX backend, following the example here (custom_gradient_example), I get the error TypeError: 'NoneType' object is not callable.

The same code works with the TensorFlow backend but fails with the JAX backend. I see this behaviour on both Keras v3.9.0 and keras-nightly.
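One way to narrow this down is to check whether JAX's own custom-gradient machinery works in the same environment. The sketch below is not the Keras API and not the linked example itself. It assumes the linked example is the log(1 + e^x) function ("log1pexp") with a numerically stable gradient, and it swaps in jax.custom_vjp to register that gradient directly in JAX. If this runs while the Keras-decorated version raises the TypeError, the problem is more likely in how the Keras JAX backend wraps the function than in JAX itself.

```python
import jax
import jax.numpy as jnp


# Illustrative pure-JAX analogue (not Keras API): log(1 + e^x) with a
# hand-written gradient registered through jax.custom_vjp.
@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


def log1pexp_fwd(x):
    # Forward pass: return the primal output plus e^x as the residual.
    e = jnp.exp(x)
    return jnp.log1p(e), e


def log1pexp_bwd(e, upstream):
    # Backward pass: d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), written so that
    # e = inf (large x in float32) gives 1.0 instead of NaN from 0 * inf.
    return (upstream * (1.0 - 1.0 / (1.0 + e)),)


log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

if __name__ == "__main__":
    print(jax.grad(log1pexp)(0.0))    # 0.5
    print(jax.grad(log1pexp)(100.0))  # 1.0 (plain autodiff gives nan here)
```

Running the Keras-decorated version of the same function under KERAS_BACKEND=jax alongside this sketch shows whether the failure is specific to the Keras wrapper.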