Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions export/orbax/export/export_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ExportManager:

def __init__(
self,
module: jax_module.JaxModule,
module: jax_module.JaxModule | None,
serving_configs: Sequence[osc.ServingConfig],
):
"""ExportManager constructor.
Expand All @@ -45,9 +45,12 @@ def __init__(
serving_configs: a sequence of which each element is a `ServingConfig`
cooresponding to a serving signature of the exported SavedModel.
"""
self._version = module.export_version
self._jax_module = module
if self._version == constants.ExportModelType.ORBAX_MODEL:
if (
not self._jax_module
or self._jax_module.export_version
== constants.ExportModelType.ORBAX_MODEL
):
self._serialization_functions = obm_export.ObmExport(
self._jax_module, serving_configs
)
Expand All @@ -59,7 +62,11 @@ def __init__(
@property
def tf_module(self) -> tf.Module:
"""Returns the tf.module maintained by the export manager."""
if self._version == constants.ExportModelType.ORBAX_MODEL:
if (
not self._jax_module
or self._jax_module.export_version
== constants.ExportModelType.ORBAX_MODEL
):
raise TypeError(
'tf_module is not implemented for export version'
' ExportModelType.ORBAX_MODEL.'
Expand Down
2 changes: 2 additions & 0 deletions export/orbax/export/export_manager_obm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

from absl.testing import absltest
from absl.testing import parameterized

from orbax.export import constants
from orbax.export import export_manager
from orbax.export import export_testing_utils
from orbax.export import oex_orchestration
from orbax.export import serving_config as sc
import tensorflow as tf


Expand Down
2 changes: 1 addition & 1 deletion export/orbax/export/obm_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ObmExport(export_base.ExportBase):

def __init__(
self,
module: jax_module.JaxModule,
module: jax_module.JaxModule | None,
serving_configs: Sequence[osc.ServingConfig],
):
"""Initializes the ObmExport class."""
Expand Down
Loading