TypeTransformer for TensorFlow model (#1562)
* TypeTransformer for TensorFlow model

Signed-off-by: Samhita Alla <[email protected]>

* clean up

Signed-off-by: Samhita Alla <[email protected]>

* clean up

Signed-off-by: Samhita Alla <[email protected]>

* fix lint

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
samhita-alla authored Mar 25, 2023
1 parent 53f134a commit ee2714f
Showing 6 changed files with 207 additions and 2 deletions.
3 changes: 2 additions & 1 deletion flytekit/extras/tensorflow/__init__.py
@@ -26,9 +26,10 @@


if _tensorflow_installed:
+   from .model import TensorFlowModelTransformer
    from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer
else:
    logger.info(
-       "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer "
+       "We won't register TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer and TensorFlowModelTransformer"
        "because tensorflow is not installed."
    )
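
Note: a minimal sketch (not part of this commit) of what the guarded import above means in practice — assuming flytekit itself is installed, the transformer is only importable, and therefore only registered with the TypeEngine, when tensorflow can be imported:

try:
    import tensorflow  # noqa: F401  # availability of tensorflow is what gates registration
    from flytekit.extras.tensorflow import TensorFlowModelTransformer
    print("TensorFlowModelTransformer is exposed and registered")
except ImportError:
    print("tensorflow is not installed; only the logger.info message above is emitted")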
76 changes: 76 additions & 0 deletions flytekit/extras/tensorflow/model.py
@@ -0,0 +1,76 @@
import pathlib
from typing import Type

import tensorflow as tf

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
    TENSORFLOW_FORMAT = "TensorFlowModel"

    def __init__(self):
        super().__init__(name="TensorFlow Model", t=tf.keras.Model)

    def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.TENSORFLOW_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: tf.keras.Model,
        python_type: Type[tf.keras.Model],
        expected: LiteralType,
    ) -> Literal:
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.TENSORFLOW_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
        )

        local_path = ctx.file_access.get_random_local_path()
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        # save model in SavedModel format
        tf.keras.models.save_model(python_val, local_path)

        remote_path = ctx.file_access.get_random_remote_path()
        ctx.file_access.put_data(local_path, remote_path, is_multipart=True)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model]
    ) -> tf.keras.Model:
        try:
            uri = lv.scalar.blob.uri
        except AttributeError:
            # the literal does not carry a blob, so the conversion cannot proceed
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=True)

        # load model
        return tf.keras.models.load_model(local_path)

    def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]:
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
            and literal_type.blob.format == self.TENSORFLOW_FORMAT
        ):
            return tf.keras.Model

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(TensorFlowModelTransformer())
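
Note: once registered, the transformer lets tasks pass tf.keras.Model values directly. A minimal usage sketch (not part of this commit; assumes flytekit and tensorflow are installed) — on the way out the model is written with tf.keras.models.save_model and uploaded as a multipart blob, and on the way in it is downloaded and rebuilt with tf.keras.models.load_model:

import tensorflow as tf

from flytekit import task, workflow


@task
def train() -> tf.keras.Model:
    # a tiny compiled model; the transformer handles (de)serialization at the task boundary
    model = tf.keras.Sequential([tf.keras.layers.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    return model


@task
def predict(model: tf.keras.Model) -> tf.Tensor:
    # the model arrives as a regular tf.keras.Model, ready for inference
    return model(tf.ones((1, 4)))


@workflow
def pipeline():
    model = train()
    predict(model=model)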
1 change: 0 additions & 1 deletion flytekit/extras/tensorflow/record.py
@@ -159,7 +159,6 @@ def to_literal(
    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory]
    ) -> TFRecordDatasetV2:

        uri, metadata = extract_metadata_and_uri(lv, expected_python_type)
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(uri, local_dir, is_multipart=True)
Empty file.
54 changes: 54 additions & 0 deletions tests/flytekit/unit/extras/tensorflow/model/test_model.py
@@ -0,0 +1,54 @@
import tensorflow as tf

from flytekit import task, workflow


@task
def generate_model() -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(32,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(),
        ],
    )
    return model


@task
def generate_sequential_model() -> tf.keras.Sequential:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Input(shape=(32,)),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(),
        ],
    )
    return model


@task
def model_forward_pass(model: tf.keras.Model) -> tf.Tensor:
    x: tf.Tensor = tf.ones((1, 32))
    return model(x)


@workflow
def wf():
    model1 = generate_model()
    model2 = generate_sequential_model()
    model_forward_pass(model=model1)
    model_forward_pass(model=model2)


def test_wf():
    wf()
@@ -0,0 +1,75 @@
from collections import OrderedDict

import numpy as np
import pytest
import tensorflow as tf

import flytekit
from flytekit import task
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.extras.tensorflow import TensorFlowModelTransformer
from flytekit.models.core.types import BlobType
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType
from flytekit.tools.translator import get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def get_tf_model():
    inputs = tf.keras.Input(shape=(32,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    tf_model = tf.keras.Model(inputs, outputs)
    return tf_model


@pytest.mark.parametrize(
    "transformer,python_type,format",
    [
        (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT),
    ],
)
def test_get_literal_type(transformer, python_type, format):
    lt = transformer.get_literal_type(python_type)
    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART))


@pytest.mark.parametrize(
    "transformer,python_type,format,python_val",
    [
        (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()),
    ],
)
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
    ctx = context_manager.FlyteContext.current_context()
    lt = transformer.get_literal_type(python_type)

    lv = transformer.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
    output = transformer.to_python_value(ctx, lv, python_type)

    assert lv.scalar.blob.metadata == BlobMetadata(
        type=BlobType(
            format=format,
            dimensionality=BlobType.BlobDimensionality.MULTIPART,
        )
    )
    assert lv.scalar.blob.uri is not None
    for w1, w2 in zip(output.weights, python_val.weights):
        np.testing.assert_allclose(w1.numpy(), w2.numpy())


def test_example_model():
    @task
    def t1() -> tf.keras.Model:
        return get_tf_model()

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorFlowModelTransformer.TENSORFLOW_FORMAT
