How to use jax2tf with ffi? #26004
krzysztofrusek asked this question in Q&A
I tried to use `jax2tf` on a function using `jnp.linalg.svd`, and the converted function cannot be run in TF because TF is missing the FFI call to LAPACK. Is there a way around this? Maybe by disabling the lowering to an FFI call in `svd`?
Code example
When I change `jnp.linalg.svd` to, e.g., `jnp.sum`, this works. I believe there is no issue on GPU, as SVD is lowered to an XLA op there.
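One way to see which ops are affected, without TensorFlow at all, is to lower the function with `jax.jit(...).lower(...)` and look for `custom_call` ops in the StableHLO text: LAPACK/FFI kernels appear there as custom calls, while ordinary ops such as `jnp.sum` do not. The sketch assumes the StableHLO text format printed by recent JAX versions. `custom_call_targets` is a hypothetical helper name, and the exact target names (something like a `lapack_..._ffi` name on CPU) vary by JAX version and backend. The same check could be run over the other numerical algorithms before exporting them.

```python
import re

import jax
import jax.numpy as jnp

# Custom calls can be printed in pretty form (`stablehlo.custom_call @name(...)`)
# or in generic form (`call_target_name = "name"`); match both.
_CUSTOM_CALL_PATTERNS = (
    re.compile(r"stablehlo\.custom_call\s+@([\w.$-]+)"),
    re.compile(r'call_target_name\s*=\s*"([^"]+)"'),
)


def custom_call_targets(fn, *example_args):
    """Return the sorted custom-call target names in fn's StableHLO lowering."""
    text = jax.jit(fn).lower(*example_args).as_text()
    targets = set()
    for pattern in _CUSTOM_CALL_PATTERNS:
        targets.update(pattern.findall(text))
    return sorted(targets)


x = jax.ShapeDtypeStruct((4, 4), jnp.float32)
print(custom_call_targets(jnp.sum, x))
print(custom_call_targets(lambda a: jnp.linalg.svd(a, full_matrices=False), x))
```

On CPU the first list should be empty and the second should contain the LAPACK SVD kernel. Not every custom call is necessarily a problem for TF, but any target that shows up here is one the TF runtime would have to provide.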
I am looking for a simple solution, such as getting the pre-FFI implementation of SVD instead of the LAPACK call.
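If no pre-FFI path is available, one workaround is to write the SVD in plain `jax.numpy`/`lax`, so that it lowers only to standard StableHLO ops (a loop, gathers, scatters) instead of a LAPACK custom call. This is a swapped-in technique, not JAX's own SVD. The sketch uses one-sided (Hestenes) Jacobi rotations and assumes a real matrix with `m >= n`, a fixed number of sweeps instead of a convergence test, and unsorted singular values. `jacobi_svd` is a hypothetical helper. It will be much slower than LAPACK for large matrices, and it has not been checked end to end through `jax2tf` and the TF Java bindings.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def jacobi_svd(a, num_sweeps=30):
    """Thin SVD of a real (m, n) matrix with m >= n via one-sided Jacobi.

    Returns u (m, n), s (n,), vt (n, n) with a ~= u @ diag(s) @ vt.
    Singular values are NOT sorted.
    """
    m, n = a.shape
    # Static list of column pairs (p, q) with p < q, visited once per sweep.
    pairs = jnp.asarray(
        np.array(list(itertools.combinations(range(n), 2)), dtype=np.int32).reshape(-1, 2)
    )

    def rotate(k, carry):
        u, v = carry
        p, q = pairs[k, 0], pairs[k, 1]
        up, uq = u[:, p], u[:, q]
        alpha = jnp.sum(up * up)
        beta = jnp.sum(uq * uq)
        gamma = jnp.sum(up * uq)
        # Columns already orthogonal: use t = 0 (identity rotation).
        nonzero = gamma != 0
        safe_gamma = jnp.where(nonzero, gamma, 1.0)
        zeta = (beta - alpha) / (2.0 * safe_gamma)
        sign = jnp.where(zeta >= 0, 1.0, -1.0)
        # Smaller root of t^2 + 2*zeta*t - 1 = 0; hypot avoids overflow.
        t = jnp.where(nonzero, sign / (jnp.abs(zeta) + jnp.hypot(1.0, zeta)), 0.0)
        c = 1.0 / jnp.sqrt(1.0 + t * t)
        s = c * t
        vp, vq = v[:, p], v[:, q]
        # Apply the same plane rotation to columns p, q of U and V.
        u = u.at[:, p].set(c * up - s * uq).at[:, q].set(s * up + c * uq)
        v = v.at[:, p].set(c * vp - s * vq).at[:, q].set(s * vp + c * vq)
        return u, v

    def sweep(_, carry):
        return jax.lax.fori_loop(0, pairs.shape[0], rotate, carry)

    u, v = jax.lax.fori_loop(0, num_sweeps, sweep, (a, jnp.eye(n, dtype=a.dtype)))
    # After convergence a @ v has orthogonal columns; their norms are the singular values.
    s = jnp.sqrt(jnp.sum(u * u, axis=0))
    u = u / jnp.where(s > 0, s, 1.0)
    return u, s, v.T
```

The lowered text of `jax.jit(jacobi_svd)` should contain no `custom_call`. If callers expect LAPACK's descending order, the results can be reordered with `jnp.argsort(-s)`, which lowers to an ordinary sort op.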
FYI 1: the code is going to be executed in Java with the TensorFlow bindings.
FYI 2: we planned to implement all numerical algorithms in JAX and use SavedModel to bring them into the application, so there may be other ops that get lowered to FFI calls.