exporting a function that takes custom pytree as an argument #22067

Answered by jakevdp
frskplis asked this question in General
You can't do this directly: exported code needs to be loadable in a different Python environment, and there's no good way to serialize your class definition so that it can be loaded there (short of using something like pickle, which has other well-documented issues).

A workaround is to define your function in terms of the flattened model buffers, and then reconstruct the model inside the function:

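import jax
from jax import export

# Model is the custom pytree class from the question.
model = Model(1)
# Split the model into a list of leaf buffers plus a treedef describing its structure.
model_flat, model_treedef = jax.tree.flatten(model)

@jax.jit
def func2(flattened_model):
  # Rebuild the model from its leaves; the treedef is captured at trace time.
  model = jax.tree.unflatten(model_treedef, flattened_model)
  return model.x

exported: export.Exported = export.export(func2)(model_flat)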
blob: bytearray = exported.se…
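
For anyone who wants to try the workaround end to end, here is a minimal self-contained sketch. The `Model` class below is a hypothetical stand-in (the real class from the question is not shown here), registered with `jax.tree_util.register_pytree_node_class`, and the function doubles `x` only so that the exported computation is non-trivial. It assumes a JAX version that ships `jax.export` and `jax.tree`.

```python
import jax
import jax.numpy as jnp
from jax import export


# Hypothetical stand-in for the custom pytree class from the question.
@jax.tree_util.register_pytree_node_class
class Model:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # (children, aux_data): x is the only leaf, no static metadata.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = Model(jnp.arange(3.0))
# model_flat is a plain list of arrays; model_treedef remembers the class structure.
model_flat, model_treedef = jax.tree.flatten(model)


@jax.jit
def func2(flattened_model):
    # The treedef is closed over, so the custom class is only needed at trace time.
    rebuilt = jax.tree.unflatten(model_treedef, flattened_model)
    return rebuilt.x * 2.0


# Export against the flat list of leaves, then round-trip through bytes.
exported = export.export(func2)(model_flat)
blob = exported.serialize()

# The loading side only needs the blob and a list of arrays, not the Model class.
restored = export.deserialize(blob)
result = restored.call(model_flat)
```

The point of this design is that the exported signature contains only a built-in container (a list of arrays), so the environment that deserializes the blob never has to import or reconstruct the custom class.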

Answer selected by frskplis