138 changes: 138 additions & 0 deletions export/orbax/export/export_manager_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Mapping
from absl.testing import parameterized
import jax
import jax.numpy as jnp
@@ -144,6 +145,143 @@ def test_get_serving_signatures_orbax_export(self):
with self.assertRaises(NotImplementedError):
em.serving_signatures # pylint: disable=pointless-statement

def test_save_model_with_preprocess_output_passthrough_succeeds(self):
"""Tests that the model can be saved with preprocess output passthrough."""

rng = jax.random.PRNGKey(0)
params = {
'w': jax.random.normal(rng, shape=(8, 8)),
'b': jax.random.normal(rng, shape=(8, 1)),
}
bs = 1

@tf.function
def tf_preprocessor(x_float: tf.Tensor, y_str: tf.Tensor):
# Returns two outputs, one to be passed to the JAX function, the other to
# be passed to the TF postprocessor.
return {'pre_out_float': x_float}, {'pre_out_str': y_str}

def jax_func(
params: Mapping[str, jax.Array], inputs: Mapping[str, jax.Array]
):
outputs = params['w'] @ inputs['pre_out_float'] + params['b']
return {
'jax_out_float': outputs,
}

# The TF postprocessor gets two inputs: the JAX function output and the
# second element of the preprocessor output.
@tf.function
def tf_postprocessor(
inputs: Mapping[str, tf.Tensor], inputs_extra: Mapping[str, tf.Tensor]
):
return {
'post_out_float': inputs['jax_out_float'],
'post_out_str': inputs_extra['pre_out_str'],
}

m = jax_module.JaxModule(
params,
jax_func,
)
serving_configs = [
sc.ServingConfig(
'serving_default',
[
tf.TensorSpec((bs, 8, 1), tf.float32, name='x_float'),
tf.TensorSpec((bs,), tf.string, name='y_str'),
],
tf_preprocessor=tf_preprocessor,
tf_postprocessor=tf_postprocessor,
preprocess_output_passthrough_enabled=True,
)
]
em = export_manager.ExportManager(m, serving_configs)
em.save(
self._output_dir,
)

x_float = jax.random.normal(rng, shape=(bs, 8, 1))
y_str = tf.constant(['a dummy string'] * bs)
loaded = em.load(self._output_dir)
self.assertAllClose(
loaded.signatures['serving_default'](
x_float=tf.convert_to_tensor(x_float, dtype=tf.float32),
y_str=y_str,
)['post_out_float'],
params['w'] @ x_float + params['b'],
atol=0.05,
rtol=0.2,
)
self.assertAllEqual(
loaded.signatures['serving_default'](
x_float=tf.convert_to_tensor(x_float),
y_str=y_str,
)['post_out_str'],
y_str,
)

def test_save_jax2tf_model_with_preprocess_output_passthrough_raises_error(
self,
):
"""Tests that the model saving with preprocess output passthrough raises error.

The error is raised because the preprocessor output doesn't comply with
the requirements of a tuple of two dicts.
"""
rng = jax.random.PRNGKey(0)
params = {
'w': jax.random.normal(rng, shape=(8, 8)),
'b': jax.random.normal(rng, shape=(8, 1)),
}
bs = 1

@tf.function
def tf_preprocessor(x_float: tf.Tensor, y_str: tf.Tensor):
return {'pre_out_float': x_float, 'pre_out_str': y_str}

def jax_func(
params: Mapping[str, jax.Array], inputs: Mapping[str, jax.Array]
):
outputs = params['w'] @ inputs['pre_out_float'] + params['b']
return {
'jax_out_float': outputs,
}

@tf.function
def tf_postprocessor(
inputs: Mapping[str, tf.Tensor], inputs_extra: Mapping[str, tf.Tensor]
):
return {
'post_out_float': inputs['jax_out_float'],
'post_out_str': inputs_extra['pre_out_str'],
}

m = jax_module.JaxModule(
params,
jax_func,
)
serving_configs = [
sc.ServingConfig(
'serving_default',
[
tf.TensorSpec((bs, 8, 1), tf.float32, name='x_float'),
tf.TensorSpec((bs,), tf.string, name='y_str'),
],
tf_preprocessor=tf_preprocessor,
tf_postprocessor=tf_postprocessor,
preprocess_output_passthrough_enabled=True,
)
]
em = export_manager.ExportManager(m, serving_configs)
with self.assertRaisesRegex(
ValueError,
'requiring the preprocessor output to be a tuple of two elements',
):
em.save(
self._output_dir,
)


if __name__ == '__main__':
tf.test.main()
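For readers checking the exported artifact outside the test harness, a minimal loading-side sketch follows. It assumes the SavedModel was exported with the serving config from the test above and that '/tmp/passthrough_model' is a placeholder output directory; tf.saved_model.load is the standard TensorFlow API for reading the model back, and the exact tensor shapes depend on the input_signature used at export time.

import tensorflow as tf

loaded = tf.saved_model.load('/tmp/passthrough_model')  # placeholder path
serving_fn = loaded.signatures['serving_default']
outputs = serving_fn(
    x_float=tf.zeros((1, 8, 1), tf.float32),
    y_str=tf.constant(['a dummy string']),
)
# 'post_out_float' is produced by the JAX function; 'post_out_str' is the
# string passed straight from the preprocessor to the postprocessor.
print(outputs['post_out_float'].shape, outputs['post_out_str'])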
86 changes: 72 additions & 14 deletions export/orbax/export/serving_config.py
@@ -65,6 +65,43 @@ class ServingConfig:
# Options passed to the Orbax Model export.
obm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)

# When set to True, a portion of the preprocessor's outputs is passed
# directly to the tf_postprocessor, bypassing the JAX function.
#
# The primary use case of this option is to handle preprocessor outputs
# containing string tensors that cannot be passed to the JAX function but
# are required by the postprocessor.
#
# Prerequisites:
# This option requires the preprocessor outputs and postprocessor inputs
# to be structured as follows:
# - The preprocessor must return a tuple of two outputs: the first is passed
#   as the input to the JAX function and the second is passed as the
#   second input to the postprocessor.
# - The JAX function must take one input and return one output. The output
#   is passed as the first input to the postprocessor.
# - The postprocessor must take two inputs: the first is the output of the
#   JAX function and the second is the second element of the preprocessor
#   outputs.
#
# For example:
#
# def tf_preprocessor(x):
# return {'pre_out_to_jax': x}, {'pre_out_to_post': x}
#
# def jax_func(inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
# return {'jax_out_to_post': inputs['pre_out_to_jax']}
#
# def tf_postprocessor(
# inputs: Mapping[str, tf.Tensor],
# inputs_extra: Mapping[str, tf.Tensor],
# ) -> Mapping[str, tf.Tensor]:
# return {
# 'post_out_from_jax': inputs['jax_out_to_post'],
# 'post_out_from_pre': inputs_extra['pre_out_to_post'],
# }
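#
# A sketch of wiring these functions into a ServingConfig (the signature
# key and TensorSpec below are illustrative, mirroring the unit tests
# rather than a required layout):
#
# config = ServingConfig(
#     'serving_default',
#     [tf.TensorSpec((1, 8, 1), tf.float32, name='x')],
#     tf_preprocessor=tf_preprocessor,
#     tf_postprocessor=tf_postprocessor,
#     preprocess_output_passthrough_enabled=True,
# )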
preprocess_output_passthrough_enabled: bool = False

def get_signature_keys(self) -> Sequence[str]:
if isinstance(self.signature_key, str):
return [self.signature_key]
@@ -158,45 +195,66 @@ def make_inference_fn(infer_step):

    def inference_fn(*inputs):
      if self.tf_preprocessor:
        preprocessor_outputs = preprocessor(*inputs)
        if require_numpy:
          preprocessor_outputs = jax.tree_util.tree_map(
              lambda x: x.numpy(), preprocessor_outputs
          )
        if self.preprocess_output_passthrough_enabled:
          if (
              not isinstance(preprocessor_outputs, tuple)
              or len(preprocessor_outputs) != 2
          ):
            raise ValueError(
                '`preprocess_output_passthrough_enabled` is enabled,'
                ' requiring the preprocessor output to be a tuple of two'
                f' elements, but got {preprocessor_outputs} with'
                f' type={type(preprocessor_outputs)} and'
                f' length={len(preprocessor_outputs)}.'
            )
          jax_inputs, postprocessor_inputs_extra = preprocessor_outputs
        else:
          jax_inputs = preprocessor_outputs
      else:
        jax_inputs = inputs
      if len(jax_inputs) != 1:
        raise ValueError(
            'JaxModule only takes single arg as the input, but got'
            f' len(inputs)={len(inputs)} from the preprocessor or input'
            ' signature. Please pack all inputs into one PyTree by'
            ' modifying the `input_signature` (if no `tf_preprocessor`) or'
            ' the ServingConfig.tf_preprocessor.'
        )

      jax_inputs = jax_inputs[0]

      # Currently Jax Module only takes 1 input
      jax_outputs = infer_step(jax_inputs)
      if logging.vlog_is_on(3) and require_numpy:
        if hasattr(infer_step, 'lower'):
          lower = infer_step.lower
        else:
          lower = jax.jit(infer_step).lower

        mlir_module_text = lower(
            jax_inputs,
        ).as_text()
        logging.info(
            'Jax function infer_step mlir module: = %s', mlir_module_text
        )

      if self.tf_postprocessor:
        if self.preprocess_output_passthrough_enabled:
          postprocessor_outputs = postprocessor(
              jax_outputs, postprocessor_inputs_extra
          )
        else:
          postprocessor_outputs = postprocessor(jax_outputs)
        if require_numpy:
          postprocessor_outputs = jax.tree_util.tree_map(
              lambda x: x.numpy(), postprocessor_outputs
          )
      else:
        postprocessor_outputs = jax_outputs
      return postprocessor_outputs

return inference_fn
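As a plain-Python illustration of the routing inference_fn performs when preprocess_output_passthrough_enabled is set, the sketch below mirrors the validation and data flow with stand-ins for the preprocessor output, infer_step, and postprocessor; the helper name and lambdas are purely illustrative and not part of the library.

def run_passthrough(preprocessor_outputs, infer_step, postprocessor):
  # Mirror the check above: the preprocessor must return a tuple of two elements.
  if not isinstance(preprocessor_outputs, tuple) or len(preprocessor_outputs) != 2:
    raise ValueError('preprocessor output must be a tuple of two elements')
  jax_inputs, postprocessor_inputs_extra = preprocessor_outputs
  jax_outputs = infer_step(jax_inputs)
  return postprocessor(jax_outputs, postprocessor_inputs_extra)

# A two-element tuple flows through; a single dict (as in the failing test) raises.
result = run_passthrough(
    ({'pre_out_float': 2.0}, {'pre_out_str': 'kept as-is'}),
    infer_step=lambda d: {'jax_out_float': d['pre_out_float'] * 3.0},
    postprocessor=lambda out, extra: {**out, **extra},
)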
