[Python] BigQuery handler for enrichment transform #31295

Merged · 8 commits · May 17, 2024
Changes from 7 commits
1 change: 1 addition & 0 deletions .github/trigger_files/beam_PostCommit_Python.json
@@ -1,3 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
}

1 change: 1 addition & 0 deletions CHANGES.md
@@ -71,6 +71,7 @@
* Beam YAML now supports the jinja templating syntax.
Template variables can be passed with the (json-formatted) `--jinja_variables` flag.
* DataFrame API now supports pandas 2.1.x and adds 12 more string functions for Series.([#31185](https://github.com/apache/beam/pull/31185)).
* Added BigQuery handler for enrichment transform (Python) ([#31295](https://github.com/apache/beam/pull/31295))

## Breaking Changes

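The CHANGES entry above points at the new BigQuery enrichment handler added by this PR. A rough usage sketch follows; the module path, the class name `BigQueryEnrichmentHandler`, and every parameter name in it are assumptions made for illustration rather than confirmed API, so consult the merged sources for the real signature.

```python
# Hypothetical usage sketch of the BigQuery enrichment handler (names assumed).
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
# Assumed module path and class name for the handler introduced in this PR.
from apache_beam.transforms.enrichment_handlers.bigquery import (
    BigQueryEnrichmentHandler)

handler = BigQueryEnrichmentHandler(
    project='my-gcp-project',                  # assumed parameter name
    table_name='my_dataset.product_details',   # assumed parameter name
    row_restriction_template='id = {}',        # assumed parameter name
    fields=['id'])                             # assumed parameter name

with beam.Pipeline() as p:
  enriched = (
      p
      | beam.Create([beam.Row(id=1), beam.Row(id=2)])
      | Enrichment(handler))
```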
121 changes: 79 additions & 42 deletions sdks/python/apache_beam/io/requestresponse.py
@@ -28,6 +28,8 @@
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
@@ -42,6 +44,7 @@
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics import Metrics
from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
from apache_beam.transforms.util import BatchElements
from apache_beam.utils import retry

RequestT = TypeVar('RequestT')
@@ -143,6 +146,10 @@ def get_cache_key(self, request: RequestT) -> str:
"""
return ""

def batch_elements_kwargs(self) -> Mapping[str, Any]:
"""Returns a kwargs suitable for `beam.BatchElements`."""
return {}
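The `batch_elements_kwargs` hook above is how a caller opts into batching: whatever it returns is passed to `beam.BatchElements` by `RequestResponseIO`. A minimal sketch of a custom caller that enables it follows; the caller itself and its batch sizes are made up, and it assumes a `Caller` subclass from this module only needs to implement `__call__`.

```python
# Hypothetical caller that opts into batching via batch_elements_kwargs().
from typing import Any, List, Mapping

from apache_beam.io.requestresponse import Caller


class BatchedEchoCaller(Caller):
  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    # Forwarded to beam.BatchElements by RequestResponseIO.expand().
    return {'min_batch_size': 8, 'max_batch_size': 64}

  def __call__(self, elements: List[str], *args, **kwargs) -> List[str]:
    # With batching enabled, the caller receives a list of requests and
    # returns the list of corresponding responses.
    return [e.upper() for e in elements]
```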


class ShouldBackOff(abc.ABC):
"""
@@ -476,53 +483,67 @@ def __init__(
  def __enter__(self):
    self.client = redis.Redis(self.host, self.port, **self.kwargs)

Removed (the single, per-element `__call__`):

  def __call__(self, element, *args, **kwargs):
    if self.mode == _RedisMode.READ:
      cache_request = self.source_caller.get_cache_key(element)
      # check if the caller is an enrichment handler. EnrichmentHandler
      # provides the request format for cache.
      if cache_request:
        encoded_request = self.request_coder.encode(cache_request)
      else:
        encoded_request = self.request_coder.encode(element)

      encoded_response = self.client.get(encoded_request)
      if not encoded_response:
        # no cache entry present for this request.
        return element, None

      if self.response_coder is None:
        try:
          response_dict = json.loads(encoded_response.decode('utf-8'))
          response = beam.Row(**response_dict)
        except Exception:
          _LOGGER.warning(
              'cannot decode response from redis cache for %s.' % element)
          return element, None
      else:
        response = self.response_coder.decode(encoded_response)
      return element, response
    else:
      cache_request = self.source_caller.get_cache_key(element[0])
      if cache_request:
        encoded_request = self.request_coder.encode(cache_request)
      else:
        encoded_request = self.request_coder.encode(element[0])
      if self.response_coder is None:
        try:
          encoded_response = json.dumps(element[1]._asdict()).encode('utf-8')
        except Exception:
          _LOGGER.warning(
              'cannot encode response %s for %s to store in '
              'redis cache.' % (element[1], element[0]))
          return element
      else:
        encoded_response = self.response_coder.encode(element[1])
      # Write to cache with TTL. Set nx to True to prevent overwriting for the
      # same key.
      self.client.set(
          encoded_request, encoded_response, self.time_to_live, nx=True)
      return element

Added (per-element helpers plus a `__call__` that also accepts batched lists):

  def _read_cache(self, element):
    cache_request = self.source_caller.get_cache_key(element)
    # check if the caller is an enrichment handler. EnrichmentHandler
    # provides the request format for cache.
    if cache_request:
      encoded_request = self.request_coder.encode(cache_request)
    else:
      encoded_request = self.request_coder.encode(element)

    encoded_response = self.client.get(encoded_request)
    if not encoded_response:
      # no cache entry present for this request.
      return element, None

    if self.response_coder is None:
      try:
        response_dict = json.loads(encoded_response.decode('utf-8'))
        response = beam.Row(**response_dict)
      except Exception:
        _LOGGER.warning(
            'cannot decode response from redis cache for %s.' % element)
        return element, None
    else:
      response = self.response_coder.decode(encoded_response)
    return element, response

  def _write_cache(self, element):
    cache_request = self.source_caller.get_cache_key(element[0])
    if cache_request:
      encoded_request = self.request_coder.encode(cache_request)
    else:
      encoded_request = self.request_coder.encode(element[0])
    if self.response_coder is None:
      try:
        encoded_response = json.dumps(element[1]._asdict()).encode('utf-8')
      except Exception:
        _LOGGER.warning(
            'cannot encode response %s for %s to store in '
            'redis cache.' % (element[1], element[0]))
        return element
    else:
      encoded_response = self.response_coder.encode(element[1])
    # Write to cache with TTL. Set nx to True to prevent overwriting for the
    # same key.
    self.client.set(
        encoded_request, encoded_response, self.time_to_live, nx=True)
    return element

  def __call__(self, element, *args, **kwargs):
    if self.mode == _RedisMode.READ:
      if isinstance(element, List):
        responses = [self._read_cache(e) for e in element]
        return responses
      return self._read_cache(element)
    else:
      if isinstance(element, List):
        responses = [self._write_cache(e) for e in element]
        return responses
      return self._write_cache(element)

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.client.close()
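When no response coder is configured, the helpers above fall back to JSON: `_write_cache` stores the response row's fields as JSON bytes, and `_read_cache` rebuilds a `beam.Row` from the decoded dict. A standalone sketch of that round trip follows; a plain dict stands in for the Redis client, and the data is made up.

```python
# Illustrative JSON round trip mirroring _write_cache/_read_cache above;
# a plain dict stands in for redis.Redis here.
import json

import apache_beam as beam

store = {}
request = 'user-42'

# Write path: the cache stores the response as JSON bytes keyed by the request
# (the code above obtains this dict from a Row via _asdict()).
response_dict = {'name': 'Ada', 'score': 3}
store[request.encode('utf-8')] = json.dumps(response_dict).encode('utf-8')

# Read path: decode the JSON bytes back into a beam.Row, as _read_cache does.
encoded = store.get(request.encode('utf-8'))
restored = beam.Row(**json.loads(encoded.decode('utf-8')))
assert restored.name == 'Ada' and restored.score == 3
```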
@@ -708,6 +729,13 @@ def request_coder(self, request_coder: coders.Coder):
self._request_coder = request_coder


class FlattenBatch(beam.DoFn):
  """Flatten a batched PCollection."""
  def process(self, elements, *args, **kwargs):
    for element in elements:
      yield element


class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
beam.PCollection[ResponseT]]):
"""A :class:`RequestResponseIO` transform to read and write to APIs.
@@ -753,6 +781,7 @@ def __init__(
self._repeater = NoOpsRepeater()
self._cache = cache
self._throttler = throttler
self._batching_kwargs = self._caller.batch_elements_kwargs()

def expand(
self,
@@ -774,6 +803,10 @@ def expand(
).with_outputs(
'cache_misses', main='cached_responses'))

# Batch elements if batching is enabled.
if self._batching_kwargs:
inputs = inputs | BatchElements(**self._batching_kwargs)

if isinstance(self._throttler, DefaultThrottler):
# DefaultThrottler applies throttling in the DoFn of
# Call PTransform.
@@ -796,6 +829,10 @@
should_backoff=self._should_backoff,
repeater=self._repeater))

# if batching is enabled then handle accordingly.
if self._batching_kwargs:
responses = responses | "FlattenBatch" >> beam.ParDo(FlattenBatch())

if self._cache:
# write to cache.
_ = responses | self._cache.get_write()
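When the caller's `batch_elements_kwargs` is non-empty, `expand` above wraps the call step between `BatchElements` and the new `FlattenBatch` DoFn, so downstream transforms still see individual elements. A standalone sketch of that batch-then-flatten shape follows; `beam.Map` stands in for the actual `Call` transform, and the numbers are made up.

```python
# Standalone sketch of the batch -> process -> flatten shape used by
# RequestResponseIO when batching is enabled; beam.Map stands in for Call.
import apache_beam as beam
from apache_beam.transforms.util import BatchElements


class FlattenBatch(beam.DoFn):
  """Flatten a batched PCollection, mirroring the DoFn added above."""
  def process(self, elements, *args, **kwargs):
    for element in elements:
      yield element


with beam.Pipeline() as p:
  responses = (
      p
      | beam.Create(range(10))
      | BatchElements(min_batch_size=2, max_batch_size=4)  # emits List[int]
      | beam.Map(lambda batch: [n * n for n in batch])     # handle whole batch
      | beam.ParDo(FlattenBatch()))                        # back to single ints
```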
14 changes: 8 additions & 6 deletions sdks/python/apache_beam/transforms/enrichment.py
@@ -153,12 +153,14 @@ def expand(self,
if self._cache:
self._cache.request_coder = request_coder

fetched_data = input_row | RequestResponseIO(
caller=self._source_handler,
timeout=self._timeout,
repeater=self._repeater,
cache=self._cache,
throttler=self._throttler)
fetched_data = (
input_row
| "Enrichment-RRIO" >> RequestResponseIO(
caller=self._source_handler,
timeout=self._timeout,
repeater=self._repeater,
cache=self._cache,
throttler=self._throttler))

# EnrichmentSourceHandler returns a tuple of (request,response).
return (
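The change above only gives the `RequestResponseIO` step an explicit label ("Enrichment-RRIO"); the handler contract it relies on, returning a `(request, response)` tuple that `Enrichment` then joins, is unchanged. A small sketch with a hypothetical in-memory handler follows; it assumes `EnrichmentSourceHandler` is importable from `apache_beam.transforms.enrichment`, and the lookup table is made up.

```python
# Hypothetical in-memory handler illustrating the (request, response) contract.
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment, EnrichmentSourceHandler

_LOOKUP = {1: 'alice', 2: 'bob'}  # made-up reference data


class InMemoryHandler(EnrichmentSourceHandler):
  def __call__(self, request: beam.Row, *args, **kwargs):
    # Return a (request, response) tuple; Enrichment joins the two downstream.
    name = _LOOKUP.get(request.user_id, 'unknown')
    return request, beam.Row(name=name)


with beam.Pipeline() as p:
  enriched = (
      p
      | beam.Create([beam.Row(user_id=1), beam.Row(user_id=2)])
      | Enrichment(InMemoryHandler()))
```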