@@ -35,7 +35,7 @@ class ExportManager:
3535
3636 def __init__ (
3737 self ,
38- module : jax_module .JaxModule ,
38+ module : jax_module .JaxModule | None ,
3939 serving_configs : Sequence [osc .ServingConfig ],
4040 ):
4141 """ExportManager constructor.
@@ -45,9 +45,12 @@ def __init__(
4545 serving_configs: a sequence of which each element is a `ServingConfig`
4646 cooresponding to a serving signature of the exported SavedModel.
4747 """
48- self ._version = module .export_version
4948 self ._jax_module = module
50- if self ._version == constants .ExportModelType .ORBAX_MODEL :
49+ if (
50+ not self ._jax_module
51+ or self ._jax_module .export_version
52+ == constants .ExportModelType .ORBAX_MODEL
53+ ):
5154 self ._serialization_functions = obm_export .ObmExport (
5255 self ._jax_module , serving_configs
5356 )
@@ -59,7 +62,11 @@ def __init__(
5962 @property
6063 def tf_module (self ) -> tf .Module :
6164 """Returns the tf.module maintained by the export manager."""
62- if self ._version == constants .ExportModelType .ORBAX_MODEL :
65+ if (
66+ not self ._jax_module
67+ or self ._jax_module .export_version
68+ == constants .ExportModelType .ORBAX_MODEL
69+ ):
6370 raise TypeError (
6471 'tf_module is not implemented for export version'
6572 ' ExportModelType.ORBAX_MODEL.'
0 commit comments