Is it possible to serialize a function that takes a custom pytree as an argument? This won't work:
This produces:
You can't do this directly: exported code needs to be loadable in a different Python environment, and there's no good way to serialize your class definition so that it can be loaded there (short of using something like pickle, which has other well-documented issues).

A workaround is to define your function in terms of the flattened model buffers, and then reconstruct the model inside the function:

```python
import jax
from jax import export

# `Model` is the custom pytree class from the question.
model = Model(1)
model_flat, model_treedef = jax.tree.flatten(model)

@jax.jit
def func2(flattened_model):
    # Rebuild the custom pytree from its leaves inside the traced function.
    model = jax.tree.unflatten(model_treedef, flattened_model)
    return model.x

exported: export.Exported = export.export(func2)(model_flat)
blob: bytearray = exported.serialize()
```
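For completeness, here is a minimal end-to-end sketch of the same workaround, including loading the blob back. The `Model` class below is a hypothetical stand-in (the original class from the question is not shown here), registered as a pytree with a single leaf `x`, and the function body `model.x * 2` is only illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export


# Hypothetical stand-in for the question's class: a pytree with one leaf, `x`.
@jax.tree_util.register_pytree_node_class
class Model:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # (children, aux_data): `x` is the only leaf, no static data.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = Model(jnp.float32(1.0))
# model_flat is a plain list of arrays; model_treedef stays on the Python side.
model_flat, model_treedef = jax.tree.flatten(model)


@jax.jit
def func2(flattened_model):
    # Rebuild the custom pytree inside the traced function.
    model = jax.tree.unflatten(model_treedef, flattened_model)
    return model.x * 2


# Export against the flat list, so the signature only contains built-in containers.
exported = export.export(func2)(model_flat)
blob = exported.serialize()

# Loading side: only the blob and the flat buffers are needed, not `Model`.
restored = export.deserialize(blob)
result = restored.call(model_flat)
```

The point of the design is that the exported signature only mentions a list of arrays, so the environment that deserializes the blob never has to import `Model`.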