Skip to content

Commit

Permalink
add feast feature store handler
Browse files Browse the repository at this point in the history
  • Loading branch information
riteshghorse committed Apr 12, 2024
1 parent e33dec6 commit 29ee274
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 3 deletions.
3 changes: 0 additions & 3 deletions .github/trigger_files/beam_PostCommit_Python.json

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import tempfile
from pathlib import Path
from typing import List
from typing import Optional

from feast import FeatureStore

import apache_beam as beam
from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel

# Public API of this module.
__all__ = [
'FeastFeatureStoreEnrichmentHandler',
]

# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)

# Suffix given to the temporary local copy of the feature store YAML file
# created by `download_fs_yaml_file`.
LOCAL_FEATURE_STORE_YAML_FILENAME = 'fs_yaml_file.yaml'


def download_fs_yaml_file(gcs_fs_yaml_file: str):
    """Download the feature store config file for Feast to a local temp file.

    Args:
      gcs_fs_yaml_file: path of the feature store YAML file; it is opened
        through `GCSFileSystem`.

    Returns:
      A `pathlib.Path` pointing at the local copy. The temporary file is
      created with `delete=False`, so it is not removed when it is closed.

    Raises:
      RuntimeError: if the file cannot be opened, read or copied locally.
        The message names the offending path and the underlying exception
        is attached as `__cause__`.
    """
    try:
        fs = GCSFileSystem(pipeline_options={})
        # NOTE(review): the second positional argument is passed through
        # unchanged from the original code; confirm against the
        # `GCSFileSystem.open` signature that 'r' is the intended value.
        with fs.open(gcs_fs_yaml_file, 'r') as gcs_file:
            with tempfile.NamedTemporaryFile(
                    suffix=LOCAL_FEATURE_STORE_YAML_FILENAME,
                    delete=False) as local_file:
                local_file.write(gcs_file.read())
                return Path(local_file.name)
    except Exception as err:
        # Interpolate the path (the original left a dangling '%s') and keep
        # the root cause so the real failure is visible in the traceback.
        raise RuntimeError(
            'error downloading the file %s locally to load the '
            'Feast feature store.' % gcs_fs_yaml_file) from err


def _validate_feature_names(feature_names, feature_service_name):
"""Validate either `feature_names` or `feature_service_name` is provided."""
if not bool(feature_names or feature_service_name):
raise ValueError(
'Please provide either a list of feature names to fetch '
'from online store or a feature service name for the '
'online store!')


class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
                                                                 beam.Row]):
    """Enrichment handler to interact with the Feast feature store.

    Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
    transform.

    To filter the features to enrich, use the `join_fn` param in
    :class:`apache_beam.transforms.enrichment.Enrichment`.
    """
    def __init__(
        self,
        entity_id: str,
        feature_store_yaml_path: str,
        feature_names: Optional[List[str]] = None,
        feature_service_name: str = "",
        full_feature_names: bool = False,
        *,
        exception_level: ExceptionLevel = ExceptionLevel.WARN,
    ):
        """Initializes an instance of `FeastFeatureStoreEnrichmentHandler`.

        Args:
          entity_id (str): entity name for the entity associated with the
            features. It is also the field read from each request row.
          feature_store_yaml_path (str): The path to a YAML configuration
            file for the Feast feature store.
          feature_names: A list of feature names to be retrieved from the
            online Feast feature store. The `feature_names` will be ignored
            if `feature_service_name` is also provided.
          feature_service_name (str): The name of the feature service
            containing the features to fetch from the online Feast feature
            store.
          full_feature_names (bool): Whether to use full feature names
            (including namespaces, etc.). Defaults to False.
          exception_level: a `enum.Enum` value from
            `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
            to set the level when `None` feature values are fetched from the
            online Feast store. Defaults to `ExceptionLevel.WARN`.
            NOTE(review): the value is stored but no method of this class
            reads it - confirm whether `__call__` should act on it.

        Raises:
          ValueError: if neither `feature_names` nor `feature_service_name`
            is provided.
        """
        self.entity_id = entity_id
        self.feature_store_yaml_path = feature_store_yaml_path
        self.feature_names = feature_names
        self.feature_service_name = feature_service_name
        self.full_feature_names = full_feature_names
        self._exception_level = exception_level
        _validate_feature_names(self.feature_names, self.feature_service_name)

    def __enter__(self):
        """Connect with the Feast Feature Store.

        Downloads the feature store YAML file locally, builds the
        `FeatureStore` client from it and resolves which features to fetch.

        Returns:
          This handler, so it can be bound with `with handler as h:`.

        Raises:
          RuntimeError: if the YAML file cannot be downloaded, if the
            `FeatureStore` cannot be created from it, or if the configured
            feature service cannot be retrieved.
        """
        local_repo_path = download_fs_yaml_file(self.feature_store_yaml_path)
        try:
            self.store = FeatureStore(fs_yaml_file=local_repo_path)
        except Exception as err:
            raise RuntimeError(
                'Invalid feature store yaml file provided. Make sure '
                'the `feature_store_yaml_path` contains the valid '
                'configuration for Feast feature store.') from err
        if self.feature_service_name:
            # A feature service takes precedence over `feature_names`.
            try:
                self.features = self.store.get_feature_service(
                    self.feature_service_name)
            except Exception as err:
                # The original message left '%s' unformatted and read
                # "Could find"; name the service and chain the cause.
                raise RuntimeError(
                    'Could not find the feature service %s for the feature '
                    'store configured in `feature_store_yaml_path`.' %
                    self.feature_service_name) from err
        else:
            self.features = self.feature_names
        return self

    def __call__(self, request: beam.Row, *args, **kwargs):
        """Fetches feature values for an entity-id from the Feast feature store.

        Args:
          request: the input `beam.Row` to enrich. It must contain the field
            named by `entity_id`.

        Returns:
          A tuple of the original request and a `beam.Row` built from the
          fetched feature values.
        """
        request_dict = request._asdict()
        feature_values = self.store.get_online_features(
            features=self.features,
            entity_rows=[{
                self.entity_id: request_dict[self.entity_id]
            }],
            full_feature_names=self.full_feature_names).to_dict()
        # get_online_features() returns a list of feature values per
        # entity-id. Since we do this per entity, the list of feature values
        # only contain a single element at position 0.
        response_dict = {k: v[0] for k, v in feature_values.items()}
        return request, beam.Row(**response_dict)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean the instantiated Feast feature store client."""
        self.store = None

    def get_cache_key(self, request: beam.Row) -> str:
        """Returns a string formatted with unique entity-id for the feature
        values.
        """
        return 'entity_id: %s' % request._asdict()[self.entity_id]
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests Feast feature store enrichment handler for enrichment transform.
See https://s.apache.org/feast-enrichment-test-setup
to set up test feast feature repository.
"""

import unittest

import pytest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=ungrouped-imports
try:
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.feast_feature_store import \
FeastFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import ValidateResponse # pylint: disable=line-too-long
except ImportError:
raise unittest.SkipTest(
'Feast feature store test dependencies are not installed.')


@pytest.mark.uses_feast
class TestFeastEnrichmentHandler(unittest.TestCase):
    # Integration tests for `FeastFeatureStoreEnrichmentHandler`; selected
    # through the `uses_feast` pytest marker.

    def setUp(self) -> None:
        self.feature_store_yaml_file = (
            'gs://apache-beam-testing-enrichment/'
            'feast-feature-store/repos/ecommerce/'
            'feature_repo/feature_store.yaml')
        self.feature_service_name = 'demograph_service'

    @staticmethod
    def _run_and_wait(rows, handler):
        # Build a pipeline that enriches `rows` with `handler`, run it and
        # block until it finishes.
        pipeline = beam.Pipeline()
        pipeline | beam.Create(rows) | Enrichment(handler)
        result = pipeline.run()
        result.wait_until_finish()

    def test_feast_enrichment(self):
        # Every enriched row must expose the request fields plus the
        # features served by the feature service.
        rows = [
            beam.Row(user_id=user_id, product_id=product_id)
            for user_id, product_id in ((2, 1), (6, 2), (9, 3))
        ]
        expected_fields = [
            'user_id', 'product_id', 'state', 'country', 'gender', 'age'
        ]
        handler = FeastFeatureStoreEnrichmentHandler(
            entity_id='user_id',
            feature_store_yaml_path=self.feature_store_yaml_file,
            feature_service_name=self.feature_service_name,
        )

        with TestPipeline(is_integration_test=True) as pipeline:
            enriched = pipeline | beam.Create(rows) | Enrichment(handler)
            enriched | beam.ParDo(ValidateResponse(expected_fields))

    def test_feast_enrichment_bad_feature_service_name(self):
        """Test raising an error when a bad feature service name is given."""
        rows = [beam.Row(user_id=1, product_id=1)]
        handler = FeastFeatureStoreEnrichmentHandler(
            entity_id='user_id',
            feature_store_yaml_path=self.feature_store_yaml_file,
            feature_service_name="bad_name",
        )

        with self.assertRaises(RuntimeError):
            self._run_and_wait(rows, handler)

    def test_feast_enrichment_bad_yaml_path(self):
        """Test raising an error when wrong yaml file is passed."""
        rows = [beam.Row(user_id=1, product_id=1)]

        with self.assertRaises(RuntimeError):
            # The handler is built inside the context on purpose, matching
            # the scope in which the error is expected.
            handler = FeastFeatureStoreEnrichmentHandler(
                entity_id='user_id',
                feature_store_yaml_path='gs://bad_path',
                feature_service_name="bad_name",
            )
            self._run_and_wait(rows, handler)

    def test_feast_enrichment_no_feature_service(self):
        """Test raising an error in case of no feature service name."""
        with self.assertRaises(ValueError):
            FeastFeatureStoreEnrichmentHandler(
                entity_id='user_id',
                feature_store_yaml_path=self.feature_store_yaml_file,
            )


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

feast[gcp]
1 change: 1 addition & 0 deletions sdks/python/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ markers =
vertex_ai_postcommit: vertex ai postcommits that need additional deps.
uses_testcontainer: tests that use testcontainers.
uses_mock_api: tests that use the mock API cluster.
uses_feast: tests that use feast in some way

# Default timeout intended for unit tests.
# If certain tests need a different value, please see the docs on how to
Expand Down
28 changes: 28 additions & 0 deletions sdks/python/test-suites/direct/common.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,33 @@ task testcontainersTest {
}
}

// Integration tests that use feast
task feastIntegrationTest {
// Requires the 'installGcpTest' task and the Python sdist to be built first.
dependsOn 'installGcpTest'
dependsOn ':sdks:python:sdist'
// Requirements file passed to `pip install -r` before the tests run.
def requirementsFile = "${rootDir}/sdks/python/apache_beam/transforms/enrichment_handlers/feast_tests_requirements.txt"
doFirst {
// Activate the virtualenv at ${envdir} and install the requirements file.
exec {
executable 'sh'
args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile"
}
}
doLast {
def testOpts = basicTestOpts
// Arguments for run_integration_test.sh: "collect" carries the
// `uses_feast` marker and "runner" selects TestDirectRunner.
def argMap = [
"test_opts": testOpts,
"suite": "postCommitIT-direct-py${pythonVersionSuffix}",
"collect": "uses_feast",
"runner": "TestDirectRunner"
]
def cmdArgs = mapToArgString(argMap)
// Run the integration test script inside the same virtualenv.
exec {
executable 'sh'
args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
}
}
}

// Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
project.tasks.register("inferencePostCommitIT") {
dependsOn = [
Expand All @@ -401,6 +428,7 @@ project.tasks.register("inferencePostCommitIT") {
'xgboostInferenceTest',
'transformersInferenceTest',
'testcontainersTest',
'feastIntegrationTest',
// (TODO) https://github.com/apache/beam/issues/25799
// uncomment tfx bsl tests once tfx supports protobuf 4.x
// 'tfxInferenceTest',
Expand Down

0 comments on commit 29ee274

Please sign in to comment.