Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Embeddings in mltransform #29564

Merged
merged 58 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
80e5c4a
Make base.py framework agnostic and add helper transforms
AnandInguva Nov 29, 2023
0d34847
Add tests for base.py
AnandInguva Nov 29, 2023
58b24f6
Add sentence-transformers
AnandInguva Nov 29, 2023
88f9ceb
Add tensorflow hub
AnandInguva Nov 29, 2023
23f7027
Add vertex_ai
AnandInguva Nov 29, 2023
04ebdb0
Make TFTProcessHandler a PTransform
AnandInguva Nov 29, 2023
f86c259
raise RuntimeError in ArtifactsFetcher when it is used for embeddings
AnandInguva Nov 29, 2023
fc4ec00
Add JsonPickle to requirements
AnandInguva Nov 29, 2023
3da5ce8
Add tox tests
AnandInguva Nov 29, 2023
4b4ee58
Mock frameworks in pydocs
AnandInguva Nov 29, 2023
01ba217
Add Row type check
AnandInguva Dec 4, 2023
f080c25
Remove requires_chaining
AnandInguva Dec 4, 2023
6111c31
change name of PTransformProvider to MLTransformProvider
AnandInguva Dec 4, 2023
ba24e81
remove batch_len in utility fun
AnandInguva Dec 4, 2023
d690aec
Change type annotation and redundant comments
AnandInguva Dec 4, 2023
af7496b
Remove get_transforms method
AnandInguva Dec 4, 2023
d713555
remove requires_chaining from tft
AnandInguva Dec 4, 2023
50450f3
add tests to sentence-transformers
AnandInguva Dec 4, 2023
c2b691f
Merge remote-tracking branch 'origin/master' into embeddings_mltransform
AnandInguva Dec 5, 2023
8823a75
Pass inference_args to RunInference
AnandInguva Dec 5, 2023
a7e2bd3
Add TODO GH issue
AnandInguva Dec 5, 2023
519b3ed
Merge branch 'embeddings_mltransform' of https://github.com/AnandIngu…
AnandInguva Dec 5, 2023
f77ae60
refactor variables in vertex_ai embeddings
AnandInguva Dec 5, 2023
95ed3c5
remove try/catch and throw error if options is empty for GCS artifact…
AnandInguva Dec 5, 2023
c235499
Refactor NotImplementedError message
AnandInguva Dec 5, 2023
6eebfa4
remove tensorflow hub from this PR
AnandInguva Dec 5, 2023
c27aabb
Add _validate_transform method
AnandInguva Dec 5, 2023
422a86a
add more tests
AnandInguva Dec 5, 2023
08b3665
fix test
AnandInguva Dec 6, 2023
91255ad
Fix test
AnandInguva Dec 6, 2023
c7237c3
Add more tests in sentence-transformer
AnandInguva Dec 6, 2023
a942885
use np.max instead of max
AnandInguva Dec 6, 2023
89c19fb
round to 2 decimals
AnandInguva Dec 6, 2023
2db4a20
Remove gradle command action
AnandInguva Dec 6, 2023
b7a48d5
Refactor throwing dataflow client exception
AnandInguva Dec 6, 2023
eb46e08
Merge branch 'embeddings_mltransform' of https://github.com/AnandIngu…
AnandInguva Dec 6, 2023
bad1b3b
skip the test if gcp is not installed
AnandInguva Dec 6, 2023
b850cee
remove toxTests for hub
AnandInguva Dec 6, 2023
ffff21a
remove toxTests for hub
AnandInguva Dec 6, 2023
88412ea
Fix values in assert for sentence_transformer_test
AnandInguva Dec 7, 2023
617f9d6
rename sentence_transformers to huggingface
AnandInguva Dec 7, 2023
5cae04b
fix pydocs
AnandInguva Dec 7, 2023
489200f
Change the model name for tests since it is getting different results…
AnandInguva Dec 7, 2023
816174a
Fix pydoc in vertexai
AnandInguva Dec 7, 2023
cfb1883
add suffix to artifact_location
AnandInguva Dec 8, 2023
2cb6f03
Revert "add suffix to artifact_location"
AnandInguva Dec 8, 2023
cd7050e
add no_xdist
AnandInguva Dec 8, 2023
98cd949
Try fixing pydoc for vertexai
AnandInguva Dec 8, 2023
8ea0906
change tox.ini to use pytest directly
AnandInguva Dec 8, 2023
5187b0e
Merge remote-tracking branch 'origin/master' into embeddings_mltransform
AnandInguva Dec 8, 2023
6f83d3c
raise FileExistError if Attribute file is already present
AnandInguva Dec 8, 2023
c9ddb25
Merge branch 'embeddings_mltransform' of https://github.com/AnandIngu…
AnandInguva Dec 8, 2023
9dce3cf
modify build.gradle to match tox task names
AnandInguva Dec 8, 2023
539c9ad
Add note to CHANGES.md
AnandInguva Dec 8, 2023
b967cd8
change gcs bucket to gs://temp-storage-for-perf-tests
AnandInguva Dec 8, 2023
f1bb42c
Add TODO GH links
AnandInguva Dec 11, 2023
8d0b47d
Merge remote-tracking branch 'origin/master' into embeddings_mltransform
AnandInguva Dec 11, 2023
c173d6a
Update CHANGES.md
AnandInguva Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions sdks/python/apache_beam/ml/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ class ArtifactMode(object):
class MLTransformProvider:
"""
Data processing transforms that are intended to be used with MLTransform
should subclass MLTransformProvider and implement the following methods:
1. get_ptransform_for_processing()
should subclass MLTransformProvider and implement
get_ptransform_for_processing().

get_ptransform_for_processing() method should return a PTransform that can be
used to process the data.
Expand Down Expand Up @@ -184,7 +184,7 @@ def __init__(
if kwargs:
_LOGGER.warning("Ignoring the following arguments: %s", kwargs.keys())

# TODO: Add set_model_handler method.
# TODO:https://github.com/apache/beam/pull/29564 add set_model_handler method
@abc.abstractmethod
def get_model_handler(self) -> ModelHandler:
"""
Expand Down Expand Up @@ -398,6 +398,17 @@ def save_attributes(
artifact_location,
**kwargs,
):
# if an artifact location is present, instead of overwriting the
# existing file, raise an error since the same artifact location
# can be used by multiple beam jobs and this could result in undesired
# behavior.
if FileSystems.exists(FileSystems.join(artifact_location,
_ATTRIBUTE_FILE_NAME)):
raise FileExistsError(
"The artifact location %s already exists and contains %s. Please "
"specify a different location." %
(artifact_location, _ATTRIBUTE_FILE_NAME))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call - one possible future enhancement would be to support an overwrite argument that allows users to do this


if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
temp_dir = tempfile.mkdtemp()
temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
Expand Down
13 changes: 13 additions & 0 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,19 @@ def test_with_gcs_location_with_none_options(self):
self.attribute_manager.save_attributes(
ptransform_list=[], artifact_location=path)

def test_with_same_local_artifact_location(self):
artifact_location = self.artifact_location
attribute_manager = base._JsonPickleTransformAttributeManager()

ptransform_list = [RunInference(model_handler=FakeModelHandler())]

attribute_manager.save_attributes(
ptransform_list, artifact_location=artifact_location)

with self.assertRaises(FileExistsError):
attribute_manager.save_attributes([lambda x: x],
artifact_location=artifact_location)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest
import uuid

import numpy as np
import pytest
from parameterized import parameterized

import apache_beam as beam
Expand Down Expand Up @@ -80,10 +81,11 @@
@unittest.skipIf(
SentenceTransformerEmbeddings is None,
'sentence-transformers is not installed.')
@pytest.mark.no_xdist
class SentenceTrasformerEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp()
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
self.gcs_artifact_location = os.path.join(
'gs://apache-beam-ml/testing/sentence_transformers', uuid.uuid4().hex)

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)
Expand Down Expand Up @@ -178,7 +180,6 @@ def test_sentence_transformer_with_int_data_types(self):

@parameterized.expand(_parameterized_inputs)
def test_with_gcs_artifact_location(self, inputs, model_name, output):
artifact_location = ('gs://apache-beam-ml/testing/sentence_transformers')
embedding_config = SentenceTransformerEmbeddings(
model_name=model_name, columns=[test_query_column])

Expand All @@ -187,8 +188,8 @@ def test_with_gcs_artifact_location(self, inputs, model_name, output):
p
| "CreateData" >> beam.Create(inputs)
| "MLTransform" >>
MLTransform(write_artifact_location=artifact_location).with_transform(
embedding_config))
MLTransform(write_artifact_location=self.gcs_artifact_location
).with_transform(embedding_config))
max_ele_pcoll = (
result_pcoll
| beam.Map(lambda x: round(np.max(x[test_query_column]), 2)))
Expand All @@ -200,7 +201,7 @@ def test_with_gcs_artifact_location(self, inputs, model_name, output):
p
| "CreateData" >> beam.Create(inputs)
| "MLTransform" >>
MLTransform(read_artifact_location=artifact_location))
MLTransform(read_artifact_location=self.gcs_artifact_location))
max_ele_pcoll = (
result_pcoll
| beam.Map(lambda x: round(np.max(x[test_query_column]), 2)))
Expand Down Expand Up @@ -233,7 +234,7 @@ def assert_element(element):
| beam.Map(lambda x: x[test_query_column])
| beam.Map(assert_element))

def test_mltransform_to_ptransform_with_vertex(self):
def test_mltransform_to_ptransform_with_sentence_transformer(self):
model_name = ''
transforms = [
SentenceTransformerEmbeddings(columns=['x'], model_name=model_name),
Expand Down
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ def __init__(
Args:
model_name: The name of the Vertex AI Text Embedding model.
columns: The columns containing the text to be embedded.
task_type: The downstream task for the embeddings.
Valid values are RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT,
SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING.
task_type: The downstream task for the embeddings. Valid values are
RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT, SEMANTIC_SIMILARITY,
CLASSIFICATION, CLUSTERING. For more information on the task type,
look at https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings # pylint: disable=line-too-long
title: Identifier of the text content.
project: The default GCP project for API calls.
location: The default location for API calls.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest
import uuid

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
Expand Down Expand Up @@ -44,7 +46,9 @@
VertexAITextEmbeddings is None, 'Vertex AI Python SDK is not installed.')
class VertexAIEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp()
self.artifact_location = tempfile.mkdtemp(prefix='_vertex_ai_test')
self.gcs_artifact_location = os.path.join(
'gs://apache-beam-ml/testing/vertex_ai', uuid.uuid4().hex)
damccorm marked this conversation as resolved.
Show resolved Hide resolved

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)
Expand Down Expand Up @@ -158,7 +162,6 @@ def test_with_int_data_types(self):
embedding_config))

def test_with_gcs_artifact_location(self):
artifact_location = ('gs://apache-beam-ml/testing/vertex_ai')
with beam.Pipeline() as p:
embedding_config = VertexAITextEmbeddings(
model_name=model_name, columns=[test_query_column])
Expand All @@ -172,7 +175,7 @@ def test_with_gcs_artifact_location(self):
_ = self.pipeline_with_configurable_artifact_location(
pipeline=data,
embedding_config=embedding_config,
write_artifact_location=artifact_location)
write_artifact_location=self.gcs_artifact_location)

with beam.Pipeline() as p:
data = (
Expand All @@ -183,7 +186,7 @@ def test_with_gcs_artifact_location(self):
test_query_column: test_query
}]))
result_pcoll = self.pipeline_with_configurable_artifact_location(
pipeline=data, read_artifact_location=artifact_location)
pipeline=data, read_artifact_location=self.gcs_artifact_location)

def assert_element(element):
assert round(element, 2) == 0.15
Expand Down
4 changes: 0 additions & 4 deletions sdks/python/apache_beam/ml/transforms/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
"artifact_location is not specified. Please specify the "
"artifact_location for the op %s" % self.__class__.__name__)

transforms = kwargs.get('transforms')
if transforms:
params['transforms'] = transforms

artifact_mode = kwargs.get('artifact_mode')
if artifact_mode:
params['artifact_mode'] = artifact_mode
Expand Down
5 changes: 3 additions & 2 deletions sdks/python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,13 @@ commands =
/bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_vertex_ai {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'


[testenv:py{38,39,310,311}-sentence-transformers-222]
[testenv:py{38,39,310,311}-embeddings]
deps =
sentence-transformers==2.2.2
extras = test,gcp
commands =
# Log aiplatform and its dependencies version for debugging
AnandInguva marked this conversation as resolved.
Show resolved Hide resolved
/bin/sh -c "pip freeze | grep -E sentence-transformers"
/bin/sh -c "pip freeze | grep -E google-cloud-aiplatform"
# Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
/bin/sh -c 'pytest apache_beam/ml/transforms/embeddings -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'
/bin/sh -c 'pytest apache_beam/ml/transforms/embeddings -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'