Skip to content

Commit 9c2045d

Browse files
hejiang0116Orbax Authors
authored andcommitted
Make jax_module optional in ObmExport.
PiperOrigin-RevId: 825586752
1 parent 2976643 commit 9c2045d

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

export/orbax/export/export_manager.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ExportManager:
3535

3636
def __init__(
3737
self,
38-
module: jax_module.JaxModule,
38+
module: jax_module.JaxModule | None,
3939
serving_configs: Sequence[osc.ServingConfig],
4040
):
4141
"""ExportManager constructor.
@@ -45,9 +45,12 @@ def __init__(
4545
serving_configs: a sequence of which each element is a `ServingConfig`
4646
cooresponding to a serving signature of the exported SavedModel.
4747
"""
48-
self._version = module.export_version
4948
self._jax_module = module
50-
if self._version == constants.ExportModelType.ORBAX_MODEL:
49+
if (
50+
not self._jax_module
51+
or self._jax_module.export_version
52+
== constants.ExportModelType.ORBAX_MODEL
53+
):
5154
self._serialization_functions = obm_export.ObmExport(
5255
self._jax_module, serving_configs
5356
)
@@ -59,7 +62,11 @@ def __init__(
5962
@property
6063
def tf_module(self) -> tf.Module:
6164
"""Returns the tf.module maintained by the export manager."""
62-
if self._version == constants.ExportModelType.ORBAX_MODEL:
65+
if (
66+
not self._jax_module
67+
or self._jax_module.export_version
68+
== constants.ExportModelType.ORBAX_MODEL
69+
):
6370
raise TypeError(
6471
'tf_module is not implemented for export version'
6572
' ExportModelType.ORBAX_MODEL.'

export/orbax/export/export_manager_obm_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818

1919
from absl.testing import absltest
2020
from absl.testing import parameterized
21+
2122
from orbax.export import constants
2223
from orbax.export import export_manager
2324
from orbax.export import export_testing_utils
2425
from orbax.export import oex_orchestration
26+
from orbax.export import serving_config as sc
2527
import tensorflow as tf
2628

2729

export/orbax/export/obm_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class ObmExport(export_base.ExportBase):
4343

4444
def __init__(
4545
self,
46-
module: jax_module.JaxModule,
46+
module: jax_module.JaxModule | None,
4747
serving_configs: Sequence[osc.ServingConfig],
4848
):
4949
"""Initializes the ObmExport class."""

0 commit comments

Comments
 (0)