diff --git a/CHANGES.md b/CHANGES.md index 7686b7a92d96..60b5a820cf3b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -67,6 +67,7 @@ * Python GCSIO is now implemented with GCP GCS Client instead of apitools ([#25676](https://github.com/apache/beam/issues/25676)) * Adding support for LowCardinality DataType in ClickHouse (Java) ([#29533](https://github.com/apache/beam/pull/29533)). * Added support for handling bad records to KafkaIO (Java) ([#29546](https://github.com/apache/beam/pull/29546)) +* Add support for generating text embeddings in MLTransform for Vertex AI and Hugging Face Hub models.([#29564](https://github.com/apache/beam/pull/29564)) ## New Features / Improvements diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index b3a30bb5f125..d5f4d1b60e14 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -14,20 +14,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pytype: skip-file - import abc +import collections +import logging +import os +import tempfile +import uuid +from typing import Any from typing import Dict from typing import Generic from typing import List +from typing import Mapping from typing import Optional from typing import Sequence from typing import TypeVar +from typing import Union + +import jsonpickle +import numpy as np import apache_beam as beam +from apache_beam.io.filesystems import FileSystems from apache_beam.metrics.metric import Metrics +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import ModelT +from apache_beam.options.pipeline_options import PipelineOptions -__all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation'] +_LOGGER = logging.getLogger(__name__) +_ATTRIBUTE_FILE_NAME = 'attributes.json' + +__all__ = [ + 'MLTransform', + 'ProcessHandler', + 'MLTransformProvider', + 'BaseOperation', + 'EmbeddingsManager' +] TransformedDatasetT = TypeVar('TransformedDatasetT') TransformedMetadataT = TypeVar('TransformedMetadataT') @@ -42,12 +64,62 @@ OperationOutputT = TypeVar('OperationOutputT') +def _convert_list_of_dicts_to_dict_of_lists( + list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]: + keys_to_element_list = collections.defaultdict(list) + for d in list_of_dicts: + for key, value in d.items(): + keys_to_element_list[key].append(value) + return keys_to_element_list + + +def _convert_dict_of_lists_to_lists_of_dict( + dict_of_lists: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + batch_length = len(next(iter(dict_of_lists.values()))) + result: List[Dict[str, Any]] = [{} for _ in range(batch_length)] + # all the values in the dict_of_lists should have same length + for key, values in dict_of_lists.items(): + assert len(values) == batch_length, ( + "This function expects all the values " + "in the dict_of_lists to have same length." + ) + for i in range(len(values)): + result[i][key] = values[i] + return result + + class ArtifactMode(object): PRODUCE = 'produce' CONSUME = 'consume' -class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC): +class MLTransformProvider: + """ + Data processing transforms that are intended to be used with MLTransform + should subclass MLTransformProvider and implement + get_ptransform_for_processing(). + + get_ptransform_for_processing() method should return a PTransform that can be + used to process the data. 
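+
+  A minimal sketch of a provider (the class name and the returned transform
+  are illustrative)::
+
+    class MyProvider(MLTransformProvider):
+      def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+        return beam.Map(lambda element: element)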
+ + """ + @abc.abstractmethod + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + """ + Returns a PTransform that can be used to process the data. + """ + + def get_counter(self): + """ + Returns the counter name for the data processing transform. + """ + counter_name = self.__class__.__name__ + return Metrics.counter(MLTransform, f'BeamML_{counter_name}') + + +class BaseOperation(Generic[OperationInputT, OperationOutputT], + MLTransformProvider, + abc.ABC): def __init__(self, columns: List[str]) -> None: """ Base Opertation class data processing transformations. @@ -76,33 +148,53 @@ def __call__(self, data: OperationInputT, transformed_data = self.apply_transform(data, output_column_name) return transformed_data - def get_counter(self): - """ - Returns the counter name for the operation. - """ - counter_name = self.__class__.__name__ - return Metrics.counter(MLTransform, f'BeamML_{counter_name}') - -class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC): +class ProcessHandler(beam.PTransform[beam.PCollection[ExampleT], + beam.PCollection[MLTransformOutputT]], + abc.ABC): """ Only for internal use. No backwards compatibility guarantees. """ @abc.abstractmethod - def process_data( - self, pcoll: beam.PCollection[ExampleT] - ) -> beam.PCollection[MLTransformOutputT]: + def append_transform(self, transform: BaseOperation): """ - Logic to process the data. This will be the entrypoint in - beam.MLTransform to process incoming data. + Append transforms to the ProcessHandler. """ + +# TODO:https://github.com/apache/beam/issues/29356 +# Add support for inference_fn +class EmbeddingsManager(MLTransformProvider): + def __init__( + self, + columns: List[str], + *, + # common args for all ModelHandlers. + load_model_args: Optional[Dict[str, Any]] = None, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + large_model: bool = False, + **kwargs): + self.load_model_args = load_model_args or {} + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + self.large_model = large_model + self.columns = columns + self.inference_args = kwargs.pop('inference_args', {}) + + if kwargs: + _LOGGER.warning("Ignoring the following arguments: %s", kwargs.keys()) + + # TODO:https://github.com/apache/beam/pull/29564 add set_model_handler method @abc.abstractmethod - def append_transform(self, transform: BaseOperation): + def get_model_handler(self) -> ModelHandler: """ - Append transforms to the ProcessHandler. + Return framework specific model handler. """ + def get_columns_to_apply(self): + return self.columns + class MLTransform(beam.PTransform[beam.PCollection[ExampleT], beam.PCollection[MLTransformOutputT]], @@ -112,7 +204,7 @@ def __init__( *, write_artifact_location: Optional[str] = None, read_artifact_location: Optional[str] = None, - transforms: Optional[Sequence[BaseOperation]] = None): + transforms: Optional[List[MLTransformProvider]] = None): """ MLTransform is a Beam PTransform that can be used to apply transformations to the data. MLTransform is used to wrap the @@ -157,9 +249,6 @@ def __init__( i-th transform is the output of the (i-1)-th transform. Multi-input transforms are not supported yet. 
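+
+    Example usage (the artifact location is illustrative; assumes the
+    tensorflow_transform based transforms are installed)::
+
+      with beam.Pipeline() as p:
+        transformed = (
+            p
+            | beam.Create([{'x': [1.0, 2.0]}])
+            | MLTransform(write_artifact_location='/tmp/artifacts')
+            .with_transform(tft.ScaleTo01(columns=['x'])))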
""" - if transforms: - _ = [self._validate_transform(transform) for transform in transforms] - if read_artifact_location and write_artifact_location: raise ValueError( 'Only one of read_artifact_location or write_artifact_location can ' @@ -177,19 +266,10 @@ def __init__( artifact_location = write_artifact_location # type: ignore[assignment] artifact_mode = ArtifactMode.PRODUCE - # avoid circular import - # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam.ml.transforms.handlers import TFTProcessHandler - # TODO: When new ProcessHandlers(eg: JaxProcessHandler) are introduced, - # create a mapping between transforms and ProcessHandler since - # ProcessHandler is not exposed to the user. - process_handler: ProcessHandler = TFTProcessHandler( - artifact_location=artifact_location, - artifact_mode=artifact_mode, - transforms=transforms) # type: ignore[arg-type] - - self._process_handler = process_handler - self.transforms = transforms + self._parent_artifact_location = artifact_location + + self._artifact_mode = artifact_mode + self.transforms = transforms or [] self._counter = Metrics.counter( MLTransform, f'BeamML_{self.__class__.__name__}') @@ -209,12 +289,34 @@ def expand( Returns: A PCollection of MLTransformOutputT type """ + _ = [self._validate_transform(transform) for transform in self.transforms] + if self._artifact_mode == ArtifactMode.PRODUCE: + ptransform_partitioner = _MLTransformToPTransformMapper( + transforms=self.transforms, + artifact_location=self._parent_artifact_location, + artifact_mode=self._artifact_mode, + pipeline_options=pcoll.pipeline.options) + ptransform_list = ptransform_partitioner.create_and_save_ptransform_list() + else: + ptransform_list = ( + _MLTransformToPTransformMapper.load_transforms_from_artifact_location( + self._parent_artifact_location)) + + # the saved transforms has artifact mode set to PRODUCE. + # set the artifact mode to CONSUME. + for i in range(len(ptransform_list)): + if hasattr(ptransform_list[i], 'artifact_mode'): + ptransform_list[i].artifact_mode = self._artifact_mode + + for ptransform in ptransform_list: + pcoll = pcoll | ptransform + _ = ( pcoll.pipeline | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self)) - return self._process_handler.process_data(pcoll) + return pcoll # type: ignore[return-value] - def with_transform(self, transform: BaseOperation): + def with_transform(self, transform: MLTransformProvider): """ Add a transform to the MLTransform pipeline. Args: @@ -223,11 +325,11 @@ def with_transform(self, transform: BaseOperation): A MLTransform instance. """ self._validate_transform(transform) - self._process_handler.append_transform(transform) + self.transforms.append(transform) return self def _validate_transform(self, transform): - if not isinstance(transform, BaseOperation): + if not isinstance(transform, MLTransformProvider): raise TypeError( 'transform must be a subclass of BaseOperation. ' 'Got: %s instead.' % type(transform)) @@ -243,9 +345,7 @@ def _increment_counters(): # increment for MLTransform. self._ml_transform._counter.inc() # increment if data processing transforms are passed. 
-    transforms = (
-        self._ml_transform.transforms or
-        self._ml_transform._process_handler.transforms)
+    transforms = self._ml_transform.transforms
     if transforms:
       for transform in transforms:
         transform.get_counter().inc()
@@ -254,3 +354,263 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to a json file in the artifact location.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from the json file in the artifact location.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Uses jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle serializes the PTransforms to a json file in a format that is
+  compatible across Python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO:https://github.com/apache/beam/issues/29356
+    # Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently only supported for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    # If the artifact location already contains an attribute file, raise an
+    # error instead of overwriting it, since the same artifact location can
+    # be used by multiple beam jobs and overwriting could result in undesired
+    # behavior.
+    if FileSystems.exists(FileSystems.join(artifact_location,
+                                           _ATTRIBUTE_FILE_NAME)):
+      raise FileExistsError(
+          "The artifact location %s already exists and contains %s. Please "
+          "specify a different location." %
+          (artifact_location, _ATTRIBUTE_FILE_NAME))
+
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        # Pipeline options are required for the client to configure the
+        # project.
+        options = kwargs.get('options')
+        try:
+          apiclient.DataflowApplicationClient(options=options).stage_file(
+              gcs_or_local_path=artifact_location,
+              file_name=_ATTRIBUTE_FILE_NAME,
+              stream=f,
+              mime_type='application/json')
+        except Exception as exc:
+          if not options:
+            raise RuntimeError(
+                "Failed to create Dataflow client. "
+                "Pipeline options are required to save the attributes "
+                "in the artifact location %s." % artifact_location) from exc
+          raise
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
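+      # The stdlib open() is used for the write since this branch only
+      # handles local paths.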
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes a list of data processing transforms compatible with
+  MLTransform and converts them into a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location, bridging
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[MLTransformProvider],
+      artifact_location: str,
+      artifact_mode: str = ArtifactMode.PRODUCE,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, MLTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of MLTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # For each PTransform instance, create a new artifact location.
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      append_transform = hasattr(current_ptransform, 'append_transform')
+      if (type(current_ptransform) !=
+          previous_ptransform_type) or not append_transform:
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+      # If the current PTransform supports append_transform, append the data
+      # processing transform to the most recently added PTransform.
+      if append_transform:
+        ptransform_list[-1].append_transform(transform)
+    return ptransform_list
+
+  def save_transforms_in_artifact_location(self, ptransform_list):
+    """
+    Save the PTransform references to a json file.
+    """
+    _transform_attribute_manager.save_attributes(
+        ptransform_list=ptransform_list,
+        artifact_location=self._parent_artifact_location,
+        options=self.pipeline_options)

+  @staticmethod
+  def load_transforms_from_artifact_location(artifact_location):
+    return _transform_attribute_manager.load_attributes(artifact_location)
+
+
+class _TextEmbeddingHandler(ModelHandler):
+  """
+  A ModelHandler intended to work on list[dict[str, str]] inputs.
+
+  The inputs to the model handler are expected to be a list of dicts.
+
+  For example, if the original model is used with RunInference to take a
+  PCollection[E] to a PCollection[P], this ModelHandler would take a
+  PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
+
+  _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
+  contains the details of the model to be loaded and the inference_fn to be
+  used. The purpose of _TextEmbeddingHandler is to generate embeddings for
+  text inputs using the EmbeddingsManager instance.
+
+  If the input is not a text column, a TypeError will be raised.
+
+  This is an internal class and offers no backwards compatibility guarantees.
+
+  Args:
+    embeddings_manager: An EmbeddingsManager instance.
+  """
+  def __init__(self, embeddings_manager: EmbeddingsManager):
+    self.embedding_config = embeddings_manager
+    self._underlying = self.embedding_config.get_model_handler()
+    self.columns = self.embedding_config.get_columns_to_apply()
+
+  def load_model(self):
+    model = self._underlying.load_model()
+    return model
+
+  def _validate_column_data(self, batch):
+    if not isinstance(batch[0], (str, bytes)):
+      raise TypeError(
+          'Embeddings can only be generated on Dict[str, str]. '
+          f'Got Dict[str, {type(batch[0])}] instead.')
+
+  def _validate_batch(self, batch: Sequence[Dict[str, List[str]]]):
+    if not batch or not isinstance(batch[0], dict):
+      raise TypeError(
+          'Expected data to be dicts, got '
+          f'{type(batch[0])} instead.')
+
+  def _process_batch(
+      self,
+      dict_batch: Dict[str, List[Any]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]:
+    result: Dict[str, List[Any]] = collections.defaultdict(list)
+    for key, batch in dict_batch.items():
+      if key in self.columns:
+        self._validate_column_data(batch)
+        prediction = self._underlying.run_inference(
+            batch, model, inference_args)
+        # Normalize numpy output to plain lists for downstream processing.
+        if isinstance(prediction, np.ndarray):
+          prediction = prediction.tolist()
+        result[key] = prediction  # type: ignore[assignment]
+      else:
+        result[key] = batch
+    return result
+
+  def run_inference(
+      self,
+      batch: Sequence[Dict[str, List[str]]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None,
+  ) -> List[Dict[str, Union[List[float], List[str]]]]:
+    """
+    Runs inference on a batch of text inputs. The inputs are expected to be
+    a list of dicts. Each dict should have the same keys, and the shape
+    should be of the same size for a single key across the batch.
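+
+    For example (values are illustrative)::
+
+      batch = [{'x': 'hello'}, {'x': 'world'}]
+      # is regrouped to {'x': ['hello', 'world']}, embeddings are computed
+      # per key, and the output is split back into one dict per element.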
+ """ + self._validate_batch(batch) + dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch) + transformed_batch = self._process_batch(dict_batch, model, inference_args) + return _convert_dict_of_lists_to_lists_of_dict( + dict_of_lists=transformed_batch, + ) + + def get_metrics_namespace(self) -> str: + return ( + self._underlying.get_metrics_namespace() or + 'BeamML_TextEmbeddingHandler') + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + batch_sizes_map = {} + if self.embedding_config.max_batch_size: + batch_sizes_map['max_batch_size'] = self.embedding_config.max_batch_size + if self.embedding_config.min_batch_size: + batch_sizes_map['min_batch_size'] = self.embedding_config.min_batch_size + return (self._underlying.batch_elements_kwargs() or batch_sizes_map) + + def validate_inference_args(self, _): + pass diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 2e447964541b..e07959436198 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -16,11 +16,16 @@ # # pytype: skip-file +import os import shutil import tempfile import typing import unittest +from typing import Any +from typing import Dict from typing import List +from typing import Optional +from typing import Sequence import numpy as np from parameterized import param @@ -28,28 +33,36 @@ import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms import base from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: - from apache_beam.ml.transforms import base from apache_beam.ml.transforms import tft + from apache_beam.ml.transforms.handlers import TFTProcessHandler from apache_beam.ml.transforms.tft import TFTOperation except ImportError: tft = None # type: ignore -if tft is None: - raise unittest.SkipTest('tensorflow_transform is not installed') +try: + class _FakeOperation(TFTOperation): + def __init__(self, name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name -class _FakeOperation(TFTOperation): - def __init__(self, name, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = name + def apply_transform(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs} +except: # pylint: disable=bare-except + pass - def apply_transform(self, inputs, output_column_name, **kwargs): - return {output_column_name: inputs} +try: + from apache_beam.runners.dataflow.internal import apiclient +except ImportError: + apiclient = None # type: ignore class BaseMLTransformTest(unittest.TestCase): @@ -59,6 +72,7 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_appends_transforms_to_process_handler_correctly(self): fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) transforms = [fake_fn_1] @@ -67,12 +81,11 @@ def test_ml_transform_appends_transforms_to_process_handler_correctly(self): ml_transform = ml_transform.with_transform( transform=_FakeOperation(name='fake_fn_2', columns=['x'])) - self.assertEqual(len(ml_transform._process_handler.transforms), 2) - self.assertEqual( - 
ml_transform._process_handler.transforms[0].name, 'fake_fn_1') - self.assertEqual( - ml_transform._process_handler.transforms[1].name, 'fake_fn_2') + self.assertEqual(len(ml_transform.transforms), 2) + self.assertEqual(ml_transform.transforms[0].name, 'fake_fn_1') + self.assertEqual(ml_transform.transforms[1].name, 'fake_fn_2') + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] data = [{'x': 1}, {'x': 2}] @@ -91,6 +104,7 @@ def test_ml_transform_on_dict(self): assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_list_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] @@ -162,6 +176,7 @@ def test_ml_transform_on_list_dict(self): }, ), ]) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_dict_output_pcoll_schema( self, input_data, input_types, expected_dtype): transforms = [tft.ScaleTo01(columns=['x'])] @@ -178,6 +193,7 @@ def test_ml_transform_dict_output_pcoll_schema( if name in expected_dtype: self.assertEqual(expected_dtype[name], typ) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): transforms = [tft.ScaleTo01(columns=['x'])] with beam.Pipeline() as p: @@ -193,6 +209,7 @@ def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): write_artifact_location=self.artifact_location, )) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_multiple_columns_single_transform(self): transforms = [tft.ScaleTo01(columns=['x', 'y'])] data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] @@ -217,6 +234,7 @@ def test_ml_transform_on_multiple_columns_single_transform(self): equal_to(expected_output_y, equals_fn=np.array_equal), label='y') + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transforms_on_multiple_columns_multiple_transforms(self): transforms = [ tft.ScaleTo01(columns=['x']), @@ -245,6 +263,7 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self): equal_to(expected_output_y, equals_fn=np.array_equal), label='actual_output_y') + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_mltransform_with_counter(self): transforms = [ tft.ComputeAndApplyVocabulary(columns=['y']), @@ -269,6 +288,298 @@ def test_mltransform_with_counter(self): self.assertEqual( result.metrics().query(mltransform_counter)['counters'][0].result, 1) + def test_non_ptransfrom_provider_class_to_mltransform(self): + class Add: + def __call__(self, x): + return x + 1 + + with self.assertRaisesRegex(TypeError, 'transform must be a subclass of'): + with beam.Pipeline() as p: + _ = ( + p + | beam.Create([{ + 'x': 1 + }]) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + Add())) + + +class FakeModel: + def __call__(self, example: List[str]) -> List[str]: + for i in range(len(example)): + example[i] = example[i][::-1] + return example + + +class FakeModelHandler(ModelHandler): + def run_inference( + self, + batch: Sequence[str], + model: Any, + inference_args: Optional[Dict[str, Any]] = None): + return model(batch) + + def load_model(self): + return FakeModel() + + +class FakeEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, columns): + super().__init__(columns=columns) + + def 
get_model_handler(self) -> ModelHandler: + return FakeModelHandler() + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=base._TextEmbeddingHandler(self))) + + +class TextEmbeddingHandlerTest(unittest.TestCase): + def setUp(self) -> None: + self.embedding_conig = FakeEmbeddingsManager(columns=['x']) + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_handler_with_incompatible_datatype(self): + text_handler = base._TextEmbeddingHandler( + embeddings_manager=self.embedding_conig) + data = [ + ('x', 1), + ('x', 2), + ('x', 3), + ] + with self.assertRaises(TypeError): + text_handler.run_inference(data, None, None) + + def test_handler_with_dict_inputs(self): + data = [ + { + 'x': "Hello world" + }, + { + 'x': "Apache Beam" + }, + ] + expected_data = [{key: value[::-1] + for key, value in d.items()} for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + + def test_handler_with_batch_sizes(self): + self.embedding_conig.max_batch_size = 100 + self.embedding_conig.min_batch_size = 10 + data = [ + { + 'x': "Hello world" + }, + { + 'x': "Apache Beam" + }, + ] * 100 + expected_data = [{key: value[::-1] + for key, value in d.items()} for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + + def test_handler_on_multiple_columns(self): + self.embedding_conig.columns = ['x', 'y'] + data = [ + { + 'x': "Hello world", 'y': "Apache Beam", 'z': 'unchanged' + }, + { + 'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged' + }, + ] + self.embedding_conig.columns = ['x', 'y'] + expected_data = [{ + key: (value[::-1] if key in self.embedding_conig.columns else value) + for key, + value in d.items() + } for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + + def test_handler_with_list_data(self): + data = [{ + 'x': ['Hello world', 'Apache Beam'], + }, { + 'x': ['Apache Beam', 'Hello world'], + }] + with self.assertRaises(TypeError): + with beam.Pipeline() as p: + _ = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + + +class TestUtilFunctions(unittest.TestCase): + def test_list_of_dicts_to_dict_of_lists_normal(self): + input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] + expected_output = {'a': [1, 3], 'b': [2, 4]} + self.assertEqual( + base._convert_list_of_dicts_to_dict_of_lists(input_list), + expected_output) + + def test_list_of_dicts_to_dict_of_lists_on_list_inputs(self): + input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] + expected_output = {'a': [[1, 2, 10], [1]], 'b': [3, 5]} + self.assertEqual( + base._convert_list_of_dicts_to_dict_of_lists(input_list), + expected_output) + + def test_dict_of_lists_to_lists_of_dict_normal(self): + input_dict = {'a': [1, 3], 'b': [2, 4]} + expected_output = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] + self.assertEqual( + 
base._convert_dict_of_lists_to_lists_of_dict(input_dict), + expected_output) + + def test_dict_of_lists_to_lists_of_dict_unequal_length(self): + input_dict = {'a': [1, 3], 'b': [2]} + with self.assertRaises(AssertionError): + base._convert_dict_of_lists_to_lists_of_dict(input_dict) + + +class TestJsonPickleTransformAttributeManager(unittest.TestCase): + def setUp(self): + self.attribute_manager = base._transform_attribute_manager + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + @unittest.skipIf(tft is None, 'tft module is not installed.') + def test_save_tft_process_handler(self): + transforms = [ + tft.ScaleTo01(columns=['x']), + tft.ComputeAndApplyVocabulary(columns=['y']) + ] + process_handler = TFTProcessHandler( + transforms=transforms, + artifact_location=self.artifact_location, + ) + self.attribute_manager.save_attributes( + ptransform_list=[process_handler], + artifact_location=self.artifact_location, + ) + + files = os.listdir(self.artifact_location) + self.assertTrue(len(files) == 1) + self.assertTrue(files[0] == base._ATTRIBUTE_FILE_NAME) + + def test_save_run_inference(self): + self.attribute_manager.save_attributes( + ptransform_list=[RunInference(model_handler=FakeModelHandler())], + artifact_location=self.artifact_location, + ) + files = os.listdir(self.artifact_location) + self.assertTrue(len(files) == 1) + self.assertTrue(files[0] == base._ATTRIBUTE_FILE_NAME) + + def test_save_and_load_run_inference(self): + ptransform_list = [RunInference(model_handler=FakeModelHandler())] + self.attribute_manager.save_attributes( + ptransform_list=ptransform_list, + artifact_location=self.artifact_location, + ) + loaded_ptransform_list = self.attribute_manager.load_attributes( + artifact_location=self.artifact_location, + ) + + self.assertTrue(len(loaded_ptransform_list) == len(ptransform_list)) + self.assertListEqual( + list(loaded_ptransform_list[0].__dict__.keys()), + list(ptransform_list[0].__dict__.keys())) + + get_keys = lambda x: list(x.__dict__.keys()) + for i, transform in enumerate(ptransform_list): + self.assertListEqual( + get_keys(transform), get_keys(loaded_ptransform_list[i])) + if hasattr(transform, 'model_handler'): + model_handler = transform.model_handler + loaded_model_handler = loaded_ptransform_list[i].model_handler + self.assertListEqual( + get_keys(model_handler), get_keys(loaded_model_handler)) + + def test_mltransform_to_ptransform_wrapper(self): + transforms = [ + FakeEmbeddingsManager(columns=['x']), + FakeEmbeddingsManager(columns=['y', 'z']), + ] + ptransform_mapper = base._MLTransformToPTransformMapper( + transforms=transforms, + artifact_location=self.artifact_location, + artifact_mode=None) + + ptransform_list = ptransform_mapper.create_ptransform_list() + self.assertTrue(len(ptransform_list) == 2) + + self.assertEqual(type(ptransform_list[0]), RunInference) + expected_columns = [['x'], ['y', 'z']] + for i in range(len(ptransform_list)): + self.assertEqual(type(ptransform_list[i]), RunInference) + self.assertEqual( + type(ptransform_list[i]._model_handler), base._TextEmbeddingHandler) + self.assertEqual( + ptransform_list[i]._model_handler.columns, expected_columns[i]) + + @unittest.skipIf(apiclient is None, 'apache_beam[gcp] is not installed.') + def test_with_gcs_location_with_none_options(self): + path = 'gs://fake_path' + with self.assertRaises(RuntimeError): + self.attribute_manager.save_attributes( + ptransform_list=[], artifact_location=path, options=None) + with 
self.assertRaises(RuntimeError): + self.attribute_manager.save_attributes( + ptransform_list=[], artifact_location=path) + + def test_with_same_local_artifact_location(self): + artifact_location = self.artifact_location + attribute_manager = base._JsonPickleTransformAttributeManager() + + ptransform_list = [RunInference(model_handler=FakeModelHandler())] + + attribute_manager.save_attributes( + ptransform_list, artifact_location=artifact_location) + + with self.assertRaises(FileExistsError): + attribute_manager.save_attributes([lambda x: x], + artifact_location=artifact_location) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py b/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py new file mode 100644 index 000000000000..bda6256b79ef --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# TODO: Add dead letter queue for RunInference transforms. + +""" +This module contains embedding configs that can be used to generate +embeddings using MLTransform. +""" diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py new file mode 100644 index 000000000000..e979296b0b83 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
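+
+"""
+This module contains embedding configs for Hugging Face sentence-transformers
+models, intended to be used with MLTransform.
+"""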
+
+__all__ = ["SentenceTransformerEmbeddings"]
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+from sentence_transformers import SentenceTransformer
+
+
+# TODO: https://github.com/apache/beam/issues/29621
+# Use HuggingFaceModelHandlerTensor once the import issue is fixed.
+# Right now, the hugging face model handler imports torch and tensorflow
+# at the same time, which adds too much weight to the container unnecessarily.
+class _SentenceTransformerModelHandler(ModelHandler):
+  """
+  Note: Intended for internal use and guarantees no backwards compatibility.
+  """
+  def __init__(
+      self,
+      model_name: str,
+      model_class: Callable,
+      load_model_args: Optional[dict] = None,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_seq_length: Optional[int] = None,
+      large_model: bool = False,
+      **kwargs):
+    self._max_seq_length = max_seq_length
+    self.model_name = model_name
+    self._model_class = model_class
+    # Guard against a None default so load_model() can always unpack args.
+    self._load_model_args = load_model_args or {}
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._large_model = large_model
+    self._kwargs = kwargs
+
+  def run_inference(
+      self,
+      batch: Sequence[str],
+      model: SentenceTransformer,
+      inference_args: Optional[Dict[str, Any]] = None,
+  ):
+    inference_args = inference_args or {}
+    return model.encode(batch, **inference_args)
+
+  def load_model(self):
+    model = self._model_class(self.model_name, **self._load_model_args)
+    if self._max_seq_length:
+      model.max_seq_length = self._max_seq_length
+    return model
+
+  def share_model_across_processes(self) -> bool:
+    return self._large_model
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    batch_sizes = {}
+    if self._min_batch_size:
+      batch_sizes["min_batch_size"] = self._min_batch_size
+    if self._max_batch_size:
+      batch_sizes["max_batch_size"] = self._max_batch_size
+    return batch_sizes
+
+
+class SentenceTransformerEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      model_name: str,
+      columns: List[str],
+      max_seq_length: Optional[int] = None,
+      **kwargs):
+    """
+    Embedding config for sentence-transformers. This config can be used with
+    MLTransform to embed text data. Models are loaded using the RunInference
+    PTransform with the help of a ModelHandler.
+
+    Args:
+      model_name: Name of the model to use. The model should be hosted on
+        HuggingFace Hub or compatible with sentence_transformers.
+      columns: List of columns to be embedded.
+      max_seq_length: Max sequence length to use for the model if applicable.
+      min_batch_size: The minimum batch size to be used for inference.
+      max_batch_size: The maximum batch size to be used for inference.
+      large_model: Whether to share the model across processes.
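+
+    Example usage (the artifact location is illustrative)::
+
+      embedding_config = SentenceTransformerEmbeddings(
+          model_name='sentence-transformers/all-mpnet-base-v2',
+          columns=['text'])
+      embedded = data | MLTransform(
+          write_artifact_location='/tmp/artifacts').with_transform(
+              embedding_config)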
+ """ + super().__init__(columns, **kwargs) + self.model_name = model_name + self.max_seq_length = max_seq_length + + def get_model_handler(self): + return _SentenceTransformerModelHandler( + model_class=SentenceTransformer, + max_seq_length=self.max_seq_length, + model_name=self.model_name, + load_model_args=self.load_model_args, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + # wrap the model handler in a _TextEmbeddingHandler since + # the SentenceTransformerEmbeddings works on text input data. + return ( + RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args, + )) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py new file mode 100644 index 000000000000..779a6daf8f3c --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +import uuid + +import numpy as np +from parameterized import parameterized + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms import base +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings + import torch +except ImportError: + SentenceTransformerEmbeddings = None # type: ignore + +# pylint: disable=ungrouped-imports +try: + import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + tft = None + +test_query = "This is a test" +test_query_column = "feature_1" +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" +_parameterized_inputs = [ + ([{ + test_query_column: 'That is a happy person' + }, { + test_query_column: 'That is a very happy person' + }], + 'thenlper/gte-base', [0.11, 0.11]), + ([{ + test_query_column: test_query, + }], DEFAULT_MODEL_NAME, [0.13]), + ( + [{ + test_query_column: 'query: how much protein should a female eat', + }, + { + test_query_column: ( + "passage: As a general guideline, the CDC's " + "average requirement of protein for women " + "ages 19 to 70 is 46 grams per day. But, " + "as you can see from this chart, you'll need " + "to increase that if you're expecting or training" + " for a marathon. 
Check out the chart below " + "to see how much protein " + "you should be eating each day.") + }], + 'intfloat/e5-base-v2', + # this model requires inputs to be specified as query: and passage: + [0.1, 0.1]), +] + + +@unittest.skipIf( + SentenceTransformerEmbeddings is None, + 'sentence-transformers is not installed.') +class SentenceTrasformerEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') + # this bucket has TTL and will be deleted periodically + self.gcs_artifact_location = os.path.join( + 'gs://temp-storage-for-perf-tests/sentence_transformers', + uuid.uuid4().hex) + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_sentence_transformer_embeddings(self): + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + with beam.Pipeline() as pipeline: + result_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 768 + + _ = (result_pcoll | beam.Map(assert_element)) + + @unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.') + def test_embeddings_with_scale_to_0_1(self): + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, + columns=[test_query_column], + ) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config).with_transform( + ScaleTo01(columns=[test_query_column]))) + + def assert_element(element): + assert max(element.feature_1) == 1 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + @parameterized.expand(_parameterized_inputs) + def test_embeddings_with_read_artifact_location( + self, inputs, model_name, output): + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + result_pcoll = ( + p + | "CreateData" >> beam.Create(inputs) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + max_ele_pcoll = ( + result_pcoll + | beam.Map(lambda x: round(max(x[test_query_column]), 2))) + + assert_that(max_ele_pcoll, equal_to(output)) + + with beam.Pipeline() as p: + result_pcoll = ( + p + | "CreateData" >> beam.Create(inputs) + | "MLTransform" >> + MLTransform(read_artifact_location=self.artifact_location)) + max_ele_pcoll = ( + result_pcoll + | beam.Map(lambda x: round(max(x[test_query_column]), 2))) + + assert_that(max_ele_pcoll, equal_to(output)) + + def test_sentence_transformer_with_int_data_types(self): + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: 1 + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + @parameterized.expand(_parameterized_inputs) + def test_with_gcs_artifact_location(self, inputs, model_name, 
output): + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + result_pcoll = ( + p + | "CreateData" >> beam.Create(inputs) + | "MLTransform" >> + MLTransform(write_artifact_location=self.gcs_artifact_location + ).with_transform(embedding_config)) + max_ele_pcoll = ( + result_pcoll + | beam.Map(lambda x: round(np.max(x[test_query_column]), 2))) + + assert_that(max_ele_pcoll, equal_to(output)) + + with beam.Pipeline() as p: + result_pcoll = ( + p + | "CreateData" >> beam.Create(inputs) + | "MLTransform" >> + MLTransform(read_artifact_location=self.gcs_artifact_location)) + max_ele_pcoll = ( + result_pcoll + | beam.Map(lambda x: round(np.max(x[test_query_column]), 2))) + + assert_that(max_ele_pcoll, equal_to(output)) + + def test_embeddings_with_inference_args(self): + model_name = DEFAULT_MODEL_NAME + + inference_args = {'convert_to_numpy': False} + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, + columns=[test_query_column], + inference_args=inference_args) + with beam.Pipeline() as pipeline: + result_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert type(element) == torch.Tensor + + _ = ( + result_pcoll + | beam.Map(lambda x: x[test_query_column]) + | beam.Map(assert_element)) + + def test_mltransform_to_ptransform_with_sentence_transformer(self): + model_name = '' + transforms = [ + SentenceTransformerEmbeddings(columns=['x'], model_name=model_name), + SentenceTransformerEmbeddings( + columns=['y', 'z'], model_name=model_name) + ] + ptransform_mapper = base._MLTransformToPTransformMapper( + transforms=transforms, + artifact_location=self.artifact_location, + artifact_mode=None) + + ptransform_list = ptransform_mapper.create_and_save_ptransform_list() + self.assertTrue(len(ptransform_list) == 2) + + self.assertEqual(type(ptransform_list[0]), RunInference) + expected_columns = [['x'], ['y', 'z']] + for i in range(len(ptransform_list)): + self.assertEqual(type(ptransform_list[i]), RunInference) + self.assertEqual( + type(ptransform_list[i]._model_handler), base._TextEmbeddingHandler) + self.assertEqual( + ptransform_list[i]._model_handler.columns, expected_columns[i]) + self.assertEqual( + ptransform_list[i]._model_handler._underlying.model_name, model_name) + ptransform_list = ( + base._MLTransformToPTransformMapper. + load_transforms_from_artifact_location(self.artifact_location)) + for i in range(len(ptransform_list)): + self.assertEqual(type(ptransform_list[i]), RunInference) + self.assertEqual( + type(ptransform_list[i]._model_handler), base._TextEmbeddingHandler) + self.assertEqual( + ptransform_list[i]._model_handler.columns, expected_columns[i]) + self.assertEqual( + ptransform_list[i]._model_handler._underlying.model_name, model_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py new file mode 100644 index 000000000000..1f4c1577eb79 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Vertex AI Python SDK is required for this module. +# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long +# to install Vertex AI Python SDK. + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence + +from google.auth.credentials import Credentials + +import apache_beam as beam +import vertexai +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from vertexai.language_models import TextEmbeddingInput +from vertexai.language_models import TextEmbeddingModel + +__all__ = ["VertexAITextEmbeddings"] + +DEFAULT_TASK_TYPE = "RETRIEVAL_DOCUMENT" +# TODO: https://github.com/apache/beam/issues/29356 +# Can this list be automatically pulled from Vertex SDK? +TASK_TYPE_INPUTS = [ + "RETRIEVAL_DOCUMENT", + "RETRIEVAL_QUERY", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING" +] +_BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. + + +class _VertexAITextEmbeddingHandler(ModelHandler): + """ + Note: Intended for internal use and guarantees no backwards compatibility. 
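+
+  Embedding requests are sent in chunks of _BATCH_SIZE texts, since the
+  Vertex AI endpoint limits the number of text instances per request.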
+ """ + def __init__( + self, + model_name: str, + title: Optional[str] = None, + task_type: str = DEFAULT_TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + ): + vertexai.init(project=project, location=location, credentials=credentials) + self.model_name = model_name + if task_type not in TASK_TYPE_INPUTS: + raise ValueError( + f"task_type must be one of {TASK_TYPE_INPUTS}, got {task_type}") + self.task_type = task_type + self.title = title + + def run_inference( + self, + batch: Sequence[str], + model: Any, + inference_args: Optional[Dict[str, Any]] = None, + ) -> Iterable: + embeddings = [] + batch_size = _BATCH_SIZE + for i in range(0, len(batch), batch_size): + text_batch = batch[i:i + batch_size] + text_batch = [ + TextEmbeddingInput( + text=text, title=self.title, task_type=self.task_type) + for text in text_batch + ] + embeddings_batch = model.get_embeddings(text_batch) + embeddings.extend([el.values for el in embeddings_batch]) + return embeddings + + def load_model(self): + model = TextEmbeddingModel.from_pretrained(self.model_name) + return model + + +class VertexAITextEmbeddings(EmbeddingsManager): + def __init__( + self, + model_name: str, + columns: List[str], + title: Optional[str] = None, + task_type: str = DEFAULT_TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + **kwargs): + """ + Embedding Config for Vertex AI Text Embedding models following + https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings # pylint: disable=line-too-long + Text Embeddings are generated for a batch of text using the Vertex AI SDK. + Embeddings are returned in a list for each text in the batch. Look at + https://cloud.google.com/vertex-ai/docs/generative-ai/learn/model-versioning#stable-versions-available.md # pylint: disable=line-too-long + for more information on model versions and lifecycle. + + Args: + model_name: The name of the Vertex AI Text Embedding model. + columns: The columns containing the text to be embedded. + task_type: The downstream task for the embeddings. Valid values are + RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT, SEMANTIC_SIMILARITY, + CLASSIFICATION, CLUSTERING. For more information on the task type, + look at https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings # pylint: disable=line-too-long + title: Identifier of the text content. + project: The default GCP project for API calls. + location: The default location for API calls. + credentials: Custom credentials for API calls. + Defaults to environment credentials. 
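+
+    Example usage (the project and artifact location are illustrative)::
+
+      embedding_config = VertexAITextEmbeddings(
+          model_name='textembedding-gecko@002',
+          columns=['text'],
+          project='my-project')
+      embedded = data | MLTransform(
+          write_artifact_location='gs://my-bucket/artifacts'
+      ).with_transform(embedding_config)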
+ """ + self.model_name = model_name + self.project = project + self.location = location + self.credentials = credentials + self.title = title + self.task_type = task_type + super().__init__(columns=columns, **kwargs) + + def get_model_handler(self) -> ModelHandler: + return _VertexAITextEmbeddingHandler( + model_name=self.model_name, + project=self.project, + location=self.location, + credentials=self.credentials, + title=self.title, + task_type=self.task_type, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return ( + RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args)) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py new file mode 100644 index 000000000000..04a730eaefb0 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py @@ -0,0 +1,249 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +import uuid + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms import base +from apache_beam.ml.transforms.base import MLTransform + +try: + from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings +except ImportError: + VertexAITextEmbeddings = None # type: ignore + +# pylint: disable=ungrouped-imports +try: + import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + tft = None + +test_query = "This is a test" +test_query_column = "feature_1" +model_name: str = "textembedding-gecko@002" + + +@unittest.skipIf( + VertexAITextEmbeddings is None, 'Vertex AI Python SDK is not installed.') +class VertexAIEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp(prefix='_vertex_ai_test') + self.gcs_artifact_location = os.path.join( + 'gs://temp-storage-for-perf-tests/vertex_ai', uuid.uuid4().hex) + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_vertex_ai_text_embeddings(self): + embedding_config = VertexAITextEmbeddings( + model_name=model_name, columns=[test_query_column]) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 768 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + @unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.') + def 
diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index e7d4f52ded85..db6ca849a625 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -388,7 +388,7 @@ def _get_transformed_data_schema(
         transformed_types[name] = typing.Sequence[bytes]  # type: ignore[assignment]
     return transformed_types
 
-  def process_data(
+  def expand(
       self, raw_data: beam.PCollection[tft_process_handler_input_type]
   ) -> beam.PCollection[tft_process_handler_output_type]:
     """
@@ -513,7 +513,7 @@ def process_data(
     # The schema only contains the columns that are transformed.
     transformed_dataset = (
-        transformed_dataset | "ConvertToRowType" >>
+        transformed_dataset
+        | "ConvertToRowType" >>
         beam.Map(lambda x: beam.Row(**x)).with_output_types(row_type))
-
     return transformed_dataset
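Renaming process_data to expand makes TFTProcessHandler a first-class beam.PTransform, so it now composes with the | operator; the test updates below follow from that. A sketch of the before/after call sites, assuming tensorflow_transform is installed and using an illustrative artifact path:

```python
import apache_beam as beam
import numpy as np
from apache_beam.ml.transforms import handlers, tft

with beam.Pipeline() as p:
  raw_data = p | beam.Create([{'x': np.array([1, 5])}])
  process_handler = handlers.TFTProcessHandler(
      transforms=[tft.ScaleTo01(columns=['x'])],
      artifact_location='/tmp/tft_artifacts')  # illustrative path
  # Before this change: process_handler.process_data(raw_data)
  # After: applied like any other PTransform.
  transformed = raw_data | process_handler
```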
diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py
index d67d8ec3e705..f13a916824c4 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers_test.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py
@@ -298,7 +298,7 @@ def test_tft_process_handler_verify_artifacts(self):
           transforms=[tft.ScaleTo01(columns=['x'])],
           artifact_location=self.artifact_location,
       )
-      _ = process_handler.process_data(raw_data)
+      _ = raw_data | process_handler
 
       self.assertTrue(
           os.path.exists(
@@ -315,7 +315,7 @@ def test_tft_process_handler_verify_artifacts(self):
       raw_data = (p | beam.Create([{'x': np.array([2, 5])}]))
       process_handler = handlers.TFTProcessHandler(
           artifact_location=self.artifact_location, artifact_mode='consume')
-      transformed_data = process_handler.process_data(raw_data)
+      transformed_data = raw_data | process_handler
       transformed_data |= beam.Map(lambda x: x.x)
 
       # the previous min is 1 and max is 6. So this should scale by (1, 6)
@@ -494,7 +494,7 @@ def test_tft_process_handler_unused_column(self):
           transforms=[scale_to_0_1_fn],
           artifact_location=self.artifact_location,
       )
-      transformed_pcoll = process_handler.process_data(raw_data)
+      transformed_pcoll = raw_data | process_handler
       transformed_pcoll_x = transformed_pcoll | beam.Map(lambda x: x.x)
       transformed_pcoll_y = transformed_pcoll | beam.Map(lambda x: x.y)
       assert_that(
@@ -520,7 +520,7 @@ def test_consume_mode_with_extra_columns_in_the_input(self):
           transforms=[tft.ScaleTo01(columns=['x'])],
           artifact_location=self.artifact_location,
       )
-      _ = process_handler.process_data(raw_data)
+      _ = raw_data | process_handler
 
       test_data = [{
           'x': np.array([2, 5]), 'y': np.array([1, 2]), 'z': 'fake_string'
@@ -548,7 +548,7 @@ def test_consume_mode_with_extra_columns_in_the_input(self):
       raw_data = (p | beam.Create(test_data))
       process_handler = handlers.TFTProcessHandler(
           artifact_location=self.artifact_location, artifact_mode='consume')
-      transformed_data = process_handler.process_data(raw_data)
+      transformed_data = raw_data | process_handler
 
       transformed_data_x = transformed_data | beam.Map(lambda x: x.x)
       transformed_data_y = transformed_data | beam.Map(lambda x: x.y)
@@ -596,7 +596,7 @@ def test_handler_with_same_input_elements(self):
         transforms=[tft.ComputeAndApplyVocabulary(columns=['x'])],
         artifact_location=self.artifact_location,
     )
-    transformed_data = process_handler.process_data(raw_data)
+    transformed_data = raw_data | process_handler
 
     expected_data = [
         beam.Row(x=np.array([4])),
diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index c7b8ff015324..8b571d9a685e 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -42,6 +42,7 @@
 from typing import Tuple
 from typing import Union
 
+import apache_beam as beam
 import tensorflow as tf
 import tensorflow_transform as tft
 from apache_beam.ml.transforms.base import BaseOperation
@@ -95,6 +96,20 @@ def __init__(self, columns: List[str]) -> None:
           "Columns are not specified. Please specify the column for the "
          " op %s" % self.__class__.__name__)
 
+  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+    from apache_beam.ml.transforms.handlers import TFTProcessHandler
+    params = {}
+    artifact_location = kwargs.get('artifact_location')
+    if not artifact_location:
+      raise RuntimeError(
+          "artifact_location is not specified. Please specify the "
+          "artifact_location for the op %s" % self.__class__.__name__)
+
+    artifact_mode = kwargs.get('artifact_mode')
+    if artifact_mode:
+      params['artifact_mode'] = artifact_mode
+    return TFTProcessHandler(artifact_location=artifact_location, **params)
+
   @tf.function
   def _split_string_with_delimiter(self, data, delimiter):
     """
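With get_ptransform_for_processing() on the base TFT operation, each data-processing transform now tells MLTransform which PTransform should process it. A sketch of the contract; MLTransform calls this internally, the path is illustrative, and artifact_mode takes the values defined on ArtifactMode:

```python
from apache_beam.ml.transforms.tft import ScaleTo01

op = ScaleTo01(columns=['x'])
# Returns a TFTProcessHandler, itself a beam.PTransform after this change.
# Omitting artifact_location raises RuntimeError.
handler = op.get_ptransform_for_processing(
    artifact_location='/tmp/tft_artifacts',  # illustrative path
    artifact_mode='produce')
```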
diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py
index 38ded6a809af..9f15db45bd28 100644
--- a/sdks/python/apache_beam/ml/transforms/tft_test.py
+++ b/sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -711,8 +711,13 @@ def test_count_per_key_on_list(self):
           ]))
 
     def validate_count_per_key(key_vocab_filename):
+      files = os.listdir(self.artifact_location)
+      files.remove(base._ATTRIBUTE_FILE_NAME)
       key_vocab_location = os.path.join(
-          self.artifact_location, 'transform_fn/assets', key_vocab_filename)
+          self.artifact_location,
+          files[0],
+          'transform_fn/assets',
+          key_vocab_filename)
       with open(key_vocab_location, 'r') as f:
         key_vocab_list = [line.strip() for line in f]
       return key_vocab_list
diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py
index 19bb02c5ae1b..fadf611b0e66 100644
--- a/sdks/python/apache_beam/ml/transforms/utils.py
+++ b/sdks/python/apache_beam/ml/transforms/utils.py
@@ -17,9 +17,11 @@
 
 __all__ = ['ArtifactsFetcher']
 
+import os
 import typing
 
 import tensorflow_transform as tft
+from apache_beam.ml.transforms import base
 
 
 class ArtifactsFetcher():
@@ -28,8 +30,18 @@ class ArtifactsFetcher():
   to the TFTProcessHandlers in MLTransform.
   """
   def __init__(self, artifact_location):
-    self.artifact_location = artifact_location
-    self.transform_output = tft.TFTransformOutput(self.artifact_location)
+    files = os.listdir(artifact_location)
+    files.remove(base._ATTRIBUTE_FILE_NAME)
+    # TODO: https://github.com/apache/beam/issues/29356
+    # Integrate ArtifactsFetcher into MLTransform.
+    if len(files) > 1:
+      raise NotImplementedError(
+          "The artifact location contains artifacts from multiple "
+          "frameworks: MLTransform was likely used with TensorFlow "
+          "Transform alongside transforms from other frameworks. "
+          "Fetching artifacts from such a location is not yet supported.")
+    self._artifact_location = os.path.join(artifact_location, files[0])
+    self.transform_output = tft.TFTransformOutput(self._artifact_location)
 
   def get_vocab_list(
       self,
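The ArtifactsFetcher change mirrors the new artifact layout: the artifact location now holds an attributes.json plus one subdirectory per processing transform, and the fetcher resolves the single TFT subdirectory. A sketch of the assumed layout; the subdirectory name is generated by MLTransform and is shown here only for illustration:

```python
# Assumed on-disk layout after MLTransform runs with a single TFT transform:
#
#   /tmp/artifacts/
#     attributes.json             # base._ATTRIBUTE_FILE_NAME
#     <generated-transform-id>/
#       transform_fn/assets/...   # standard TFTransformOutput layout
from apache_beam.ml.transforms.utils import ArtifactsFetcher

fetcher = ArtifactsFetcher(artifact_location='/tmp/artifacts')
# fetcher.transform_output is a tft.TFTransformOutput rooted at the single
# TFT subdirectory, so existing vocabulary and scaler lookups keep working.
```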
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 4fd4b97e82cd..70eb78b6ffc6 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1020,13 +1020,13 @@ def __getitem__(self, type_param):
 
 class CollectionHint(CompositeTypeHint):
   """ A Collection type-hint.
-
+
   Collection[X] defines a type-hint for a collection of homogenous types.
   'X' may be either a built-in Python type or another nested TypeConstraint.
 
   This represents a collections.abc.Collection type, which implements
   __contains__, __iter__, and __len__. This acts as a parent type for
-  sets but has fewer guarantees for mixins.
+  sets but has fewer guarantees for mixins.
   """
   class CollectionTypeConstraint(SequenceTypeConstraint):
     def __init__(self, type_param):
@@ -1302,6 +1302,8 @@ def is_consistent_with(sub, base):
   relation, but also handles the special Any type as well as type
   parameterization.
   """
+  from apache_beam.pvalue import Row
+  from apache_beam.typehints.row_type import RowTypeConstraint
   if sub == base:
     # Common special case.
     return True
@@ -1313,6 +1315,8 @@ def is_consistent_with(sub, base):
     return all(is_consistent_with(c, base) for c in sub.union_types)
   elif isinstance(base, TypeConstraint):
     return base._consistent_with_check_(sub)
+  elif isinstance(sub, RowTypeConstraint):
+    return base == Row
   elif isinstance(sub, TypeConstraint):
     # Nothing but object lives above any type constraints.
     return base == object
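The is_consistent_with change makes a concrete row schema consistent with the plain beam.Row hint, presumably so that the Row-typed outputs MLTransform produces type-check against the generic hint. A minimal illustrative check:

```python
from apache_beam.pvalue import Row
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.typehints import is_consistent_with

# A concrete row schema is now considered consistent with the generic
# beam.Row hint.
row_hint = RowTypeConstraint.from_fields([('x', int), ('y', str)])
assert is_consistent_with(row_hint, Row)
```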
diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt
index cd6018bfc1fa..6ec2cc0a7565 100644
--- a/sdks/python/container/py310/base_image_requirements.txt
+++ b/sdks/python/container/py310/base_image_requirements.txt
@@ -84,6 +84,7 @@ joblib==1.3.2
 Js2Py==0.74
 jsonschema==4.20.0
 jsonschema-specifications==2023.11.2
+jsonpickle==3.0.2
 mmh3==4.0.1
 mock==5.1.0
 nltk==3.8.1
diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt
index 1fae235ee477..435eb9712917 100644
--- a/sdks/python/container/py311/base_image_requirements.txt
+++ b/sdks/python/container/py311/base_image_requirements.txt
@@ -81,6 +81,7 @@ idna==3.6
 iniconfig==2.0.0
 joblib==1.3.2
 Js2Py==0.74
+jsonpickle==3.0.2
 jsonschema==4.20.0
 jsonschema-specifications==2023.11.2
 mmh3==4.0.1
diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt
index ab4203ecbe37..51fb324d7c44 100644
--- a/sdks/python/container/py38/base_image_requirements.txt
+++ b/sdks/python/container/py38/base_image_requirements.txt
@@ -85,6 +85,7 @@ importlib-resources==6.1.1
 iniconfig==2.0.0
 joblib==1.3.2
 Js2Py==0.74
+jsonpickle==3.0.2
 jsonschema==4.20.0
 jsonschema-specifications==2023.11.2
 mmh3==4.0.1
diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt
index 308ffa736207..ce723259aa7c 100644
--- a/sdks/python/container/py39/base_image_requirements.txt
+++ b/sdks/python/container/py39/base_image_requirements.txt
@@ -84,6 +84,7 @@ iniconfig==2.0.0
 joblib==1.3.2
 Js2Py==0.74
 jsonschema==4.20.0
+jsonpickle==3.0.2
 jsonschema-specifications==2023.11.2
 mmh3==4.0.1
 mock==5.1.0
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 3c232e126ab7..82740ae67c9f 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -133,7 +133,9 @@ autodoc_inherit_docstrings = False
 autodoc_member_order = 'bysource'
 autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub",
-  "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers"]
+  "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers",
+  "sentence_transformers",
+  ]
 
 # Allow a special section for documenting DataFrame API
 napoleon_custom_sections = ['Differences from pandas']
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index e624f3176bb3..7e6d2217d757 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -298,6 +298,7 @@ def get_portability_package_data():
         'httplib2>=0.8,<0.23.0',
         'js2py>=0.74,<1',
         'jsonschema>=4.0.0,<5.0.0',
+        'jsonpickle>=3.0.0,<4.0.0',
         # numpy can have breaking changes in minor versions.
         # Use a strict upper bound.
         'numpy>=1.14.3,<1.25.0',  # Update pyproject.toml as well.
diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle
index b1ed5f88c7c9..1e03b5058083 100644
--- a/sdks/python/test-suites/tox/py38/build.gradle
+++ b/sdks/python/test-suites/tox/py38/build.gradle
@@ -141,6 +141,10 @@ toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}"
 test.dependsOn "testPy38transformers-430"
 preCommitPyCoverage.dependsOn "testPy38transformers-430"
 
+toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}"
+test.dependsOn "testPy38embeddingsMLTransform"
+preCommitPyCoverage.dependsOn "testPy38embeddingsMLTransform"
+
 toxTask "whitespacelint", "whitespacelint", "${posargs}"
 
 task archiveFilesToLint(type: Zip) {
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index 28e282460e47..dbe90c084af2 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -419,3 +419,15 @@ commands =
   # Run all Vertex AI unit tests
   # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
   /bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_vertex_ai {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'
+
+
+[testenv:py{38,39,310,311}-embeddings]
+deps =
+  sentence-transformers==2.2.2
+extras = test,gcp
+commands =
+  # Log sentence-transformers and google-cloud-aiplatform versions for debugging.
+  /bin/sh -c "pip freeze | grep -E sentence-transformers"
+  /bin/sh -c "pip freeze | grep -E google-cloud-aiplatform"
+  # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
+  /bin/sh -c 'pytest apache_beam/ml/transforms/embeddings -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'