Skip to content

Conversation

@andreped
Copy link
Collaborator

@andreped andreped commented Jan 29, 2023

This PR adds JAX backend support to Macenko.

Changes:

  • Implemented Macenko with JAX backend and added it to base
  • Added JAX unit test CI jobs (check that JAX yields similar results to numpy backend)
  • Renamed CI names to better match their actual purpose
  • Updated README regarding JAX backend support
  • Fixed setup.py to support installation through pip install torchstain[jax]
  • Fixed np.float32 deprecation in numpy macenko
  • Removed unwanted numpy import in macenko tf backend

Note that the JAX backend runtime-wise is not as optimized as the other backends. Hence, I would perhaps say that we only have experimental JAX support as of now. Here is how JAX backend compared to the other backends:

backends numpy jax torch tf
runtime [s] 0.455 2.427 0.201 0.442

Further optimization to the JAX implementation should be done in future work, but this is outside my area of expertise. Hence, for that, it would be great if more experienced JAX developers could contribute.

@andreped andreped mentioned this pull request Jan 29, 2023
@andreped andreped added the enhancement New feature or request label Jan 29, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

1 participant