Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding support for dictionary writes to online store #4156

Merged
merged 10 commits into from
Apr 30, 2024
7 changes: 7 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,10 @@ def __init__(self, push_source_name: str):
class ReadOnlyRegistryException(Exception):
    """Exception whose message states that the registry implementation is read-only."""

    def __init__(self):
        # The message is fixed; the constructor takes no arguments.
        message = "Registry implementation is read-only."
        super().__init__(message)


class DataFrameSerializationError(Exception):
    """Error for a dictionary that could not be turned into a pandas DataFrame.

    Args:
        input_dict: The dictionary that failed to convert. Only its keys
            (not its values) are included in the error message.
    """

    def __init__(self, input_dict: dict):
        message = f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
        super().__init__(message)
17 changes: 16 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from feast.dqm.errors import ValidationFailed
from feast.entity import Entity
from feast.errors import (
DataFrameSerializationError,
DataSourceRepeatNamesException,
EntityNotFoundException,
FeatureNameCollisionError,
Expand Down Expand Up @@ -1406,7 +1407,8 @@ def push(
def write_to_online_store(
self,
feature_view_name: str,
df: pd.DataFrame,
df: Optional[pd.DataFrame] = None,
inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tokoko updated

allow_registry_cache: bool = True,
):
"""
Expand All @@ -1415,6 +1417,7 @@ def write_to_online_store(
Args:
feature_view_name: The feature view to which the dataframe corresponds.
df: The dataframe to be persisted.
inputs (optional): A dictionary mapping column names to lists of values, or a pandas DataFrame, to be persisted. Cannot be provided together with df.
allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry.
"""
# TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
Expand All @@ -1426,6 +1429,18 @@ def write_to_online_store(
feature_view = self.get_feature_view(
feature_view_name, allow_registry_cache=allow_registry_cache
)
# Exactly one of `df` / `inputs` must carry the data to write.
if df is not None and inputs is not None:
    raise ValueError("Both df and inputs cannot be provided at the same time.")
if df is None:
    if inputs is None:
        # Fail early with a clear message instead of handing None to the provider.
        raise ValueError("Either df or inputs must be provided.")
    if isinstance(inputs, pd.DataFrame):
        # A DataFrame passed through `inputs` is used as-is; previously this
        # branch left `df` as None, so None was forwarded to ingest_df.
        df = inputs
    elif isinstance(inputs, dict):
        try:
            df = pd.DataFrame(inputs)
        except Exception as err:
            # Chain the pandas error so the root cause stays in the traceback.
            raise DataFrameSerializationError(inputs) from err
    else:
        raise ValueError("inputs must be a dictionary or a pandas DataFrame.")
provider = self._get_provider()
provider.ingest_df(feature_view, df)

Expand Down
139 changes: 139 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2022 The Feast Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from datetime import datetime, timedelta
from typing import Any, Dict

from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.field import Field
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64


class TestOnlineWrites(unittest.TestCase):
    """Writes driver stats to a local SQLite online store and reads them back.

    The data is written twice in ``setUp``: once as a DataFrame (``df=``) and
    once as a column-oriented dictionary (``inputs=``).
    """

    def setUp(self):
        """Create a local feature store in a temp directory and populate it."""
        # Keep the temp directory alive for the whole test and remove it
        # afterwards. A `with` block here would delete the registry and the
        # SQLite file as soon as setUp returns, i.e. before the test method
        # reads from the store.
        temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(temp_dir.cleanup)
        data_dir = temp_dir.name

        self.store = FeatureStore(
            config=RepoConfig(
                project="test_write_to_online_store",
                registry=os.path.join(data_dir, "registry.db"),
                provider="local",
                entity_key_serialization_version=2,
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(data_dir, "online.db")
                ),
            )
        )

        # Generate test data.
        end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
        start_date = end_date - timedelta(days=15)

        driver_entities = [1001, 1002, 1003, 1004, 1005]
        driver_df = create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
        driver_df.to_parquet(
            path=driver_stats_path, allow_truncated_timestamps=True
        )

        driver = Entity(name="driver", join_keys=["driver_id"])

        driver_stats_source = FileSource(
            name="driver_hourly_stats_source",
            path=driver_stats_path,
            timestamp_field="event_timestamp",
            created_timestamp_column="created",
        )

        driver_stats_fv = FeatureView(
            name="driver_hourly_stats",
            entities=[driver],
            ttl=timedelta(days=0),
            schema=[
                Field(name="conv_rate", dtype=Float32),
                Field(name="acc_rate", dtype=Float32),
                Field(name="avg_daily_trips", dtype=Int64),
            ],
            online=True,
            source=driver_stats_source,
        )

        @on_demand_feature_view(
            sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
            schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
            mode="python",
        )
        def test_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
            """Return the element-wise sum of ``conv_rate`` and ``acc_rate``."""
            output: Dict[str, Any] = {
                "conv_rate_plus_acc": [
                    conv_rate + acc_rate
                    for conv_rate, acc_rate in zip(
                        inputs["conv_rate"], inputs["acc_rate"]
                    )
                ]
            }
            return output

        self.store.apply(
            [
                driver,
                driver_stats_source,
                driver_stats_fv,
                test_view,
            ]
        )

        # First write: the DataFrame path.
        self.store.write_to_online_store(
            feature_view_name="driver_hourly_stats", df=driver_df
        )

        # Second write: the dictionary path. orient="list" yields one list per
        # column, e.g.
        # {"driver_id": [..], "conv_rate": [..], "acc_rate": [..], "avg_daily_trips": [..]}
        driver_dict = driver_df.to_dict(orient="list")
        self.store.write_to_online_store(
            feature_view_name="driver_hourly_stats",
            inputs=driver_dict,
        )

    def test_online_retrieval(self):
        """Fetch stored and on-demand features for one driver and check the keys."""
        entity_rows = [
            {
                "driver_id": 1001,
            }
        ]

        online_python_response = self.store.get_online_features(
            entity_rows=entity_rows,
            features=[
                "driver_hourly_stats:conv_rate",
                "driver_hourly_stats:acc_rate",
                "test_view:conv_rate_plus_acc",
            ],
        ).to_dict()

        # unittest assertions keep working under `python -O` and report the
        # offending value on failure.
        self.assertEqual(len(online_python_response), 4)
        for key in ("driver_id", "acc_rate", "conv_rate", "conv_rate_plus_acc"):
            self.assertIn(key, online_python_response)
Loading