Skip to content

Commit

Permalink
feat: Install Bigframes tensorflow dependencies automatically
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571107540
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 5, 2023
1 parent 5c993d2 commit e58689b
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
serializers_base,
)

from packaging import version

try:
import cloudpickle
except ImportError:
Expand Down Expand Up @@ -125,6 +127,21 @@
_LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/"
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"

# Map tf major.minor version to tfio version from https://pypi.org/project/tensorflow-io/
_TFIO_VERSION_DICT = {
"2.3": "0.16.0", # Align with testing_extra_require: tensorflow >= 2.3.0
"2.4": "0.17.1",
"2.5": "0.19.1",
"2.6": "0.21.0",
"2.7": "0.23.1",
"2.8": "0.25.0",
"2.9": "0.26.0",
"2.10": "0.27.0",
"2.11": "0.31.0",
"2.12": "0.32.0",
"2.13": "0.34.0", # TODO(b/295580335): Support TF 2.13
}


def get_uri_prefix(gcs_uri: str) -> str:
"""Gets the directory of the gcs_uri.
Expand Down Expand Up @@ -1117,20 +1134,24 @@ def serialize(
gcs_path: str,
**kwargs,
) -> str:
# All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet)
# Record the framework in metadata for deserialization
detected_framework = kwargs.get("framework")
BigframeSerializer._metadata.framework = detected_framework
if detected_framework == "torch":
self.register_custom_command("pip install torchdata")
self.register_custom_command("pip install torcharrow")
# All bigframe serializers will convert bigframes.dataframe.DataFrame --> parquet
if not _is_valid_gcs_path(gcs_path):
raise ValueError(f"Invalid gcs path: {gcs_path}")

BigframeSerializer._metadata.dependencies = (
supported_frameworks._get_bigframe_deps()
)

# Record the framework in metadata for deserialization
detected_framework = kwargs.get("framework")
BigframeSerializer._metadata.framework = detected_framework
if detected_framework == "torch":
self.register_custom_command("pip install torchdata")
self.register_custom_command("pip install torcharrow")
elif detected_framework == "tensorflow":
tensorflow_io_dep = "tensorflow-io==" + self._get_tfio_verison()
BigframeSerializer._metadata.dependencies.append(tensorflow_io_dep)

# Check if index.name is default and set index.name if not
if to_serialize.index.name and to_serialize.index.name != "index":
raise ValueError("Index name must be 'index'")
Expand All @@ -1141,6 +1162,17 @@ def serialize(
parquet_gcs_path = gcs_path + "/*" # path is required to contain '*'
to_serialize.to_parquet(parquet_gcs_path, index=True)

def _get_tfio_verison(self):
major, minor, _ = version.Version(tf.__version__).release
tf_version = f"{major}.{minor}"

if tf_version not in _TFIO_VERSION_DICT:
raise ValueError(
f"Tensorflow version {tf_version} is not supported for Bigframes."
+ " Supported versions: tensorflow >= 2.3.0, <= 2.12.0."
)
return _TFIO_VERSION_DICT[tf_version]

def deserialize(
self, serialized_gcs_path: str, **kwargs
) -> Union[PandasData, BigframesData]:
Expand Down

0 comments on commit e58689b

Please sign in to comment.