Skip to content

Commit

Permalink
Add some arrow replay code
Browse files Browse the repository at this point in the history
  • Loading branch information
aandres committed Nov 26, 2023
1 parent a2e6b9a commit d8026ec
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
39 changes: 39 additions & 0 deletions beavers/pyarrow_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Callable

import pandas as pd
import pyarrow as pa

from beavers.engine import UTC_MAX
from beavers.replay import DataSource


class ArrowTableDataSource(DataSource[pa.Table]):
    """Replay the rows of an arrow table in timestamp order.

    The table must already be sorted by the timestamp column returned by
    ``timestamp_extractor``. Rows are handed out incrementally by
    :meth:`read_to`; each row is returned at most once.
    """

    def __init__(
        self, table: pa.Table, timestamp_extractor: Callable[[pa.Table], pa.Array]
    ):
        """Wrap ``table`` and validate its timestamp column.

        Args:
            table: The arrow table to replay.
            timestamp_extractor: Callable that, given ``table``, returns the
                arrow column holding each row's timestamp.

        Raises:
            AssertionError: If ``timestamp_extractor`` is not callable, or if
                the extracted timestamps are not monotonic increasing.
        """
        # The checks raise explicitly instead of using ``assert`` so they are
        # still enforced under ``python -O``. ``AssertionError`` is kept so
        # existing callers that catch it are unaffected.
        if not callable(timestamp_extractor):
            raise AssertionError("timestamp_extractor should be callable")
        self._table = table
        # Cached once so that "nothing to read" results share the table schema.
        self._empty_table = table.schema.empty_table()
        self._timestamp_column = timestamp_extractor(table).to_pandas(
            date_as_object=False
        )
        # ``read_to`` relies on ``searchsorted``, which needs sorted input.
        if not self._timestamp_column.is_monotonic_increasing:
            raise AssertionError("Timestamp column should be monotonic increasing")
        # Position of the first row that has not been returned yet.
        self._index = 0

    def read_to(self, timestamp: pd.Timestamp) -> pa.Table:
        """Return the unread rows whose timestamp is ``<= timestamp``.

        The read position advances past the returned rows. When no unread row
        qualifies, an empty table with the same schema is returned and the
        read position is left untouched.
        """
        # side="right" makes the bound inclusive: rows equal to ``timestamp``
        # fall before ``new_index``.
        new_index = self._timestamp_column.searchsorted(timestamp, side="right")
        if new_index <= self._index:
            return self._empty_table
        from_index = self._index
        self._index = new_index
        return self._table.slice(from_index, new_index - from_index)

    def get_next(self) -> pd.Timestamp:
        """Return the timestamp of the next unread row.

        Returns ``UTC_MAX`` once every row of the table has been read.
        """
        if self._index >= len(self._table):
            return UTC_MAX
        return self._timestamp_column.iloc[self._index]
49 changes: 49 additions & 0 deletions tests/test_pyarrow_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from operator import itemgetter

import pandas as pd
import pyarrow as pa
import pytest

from beavers.engine import UTC_MAX
from beavers.pyarrow_replay import ArrowTableDataSource


def test_arrow_table_data_source():
    """Step through a two-row table and check each read and next timestamp."""
    first_ts = pd.to_datetime("2023-01-01T00:00:00Z")
    second_ts = pd.to_datetime("2023-01-02T00:00:00Z")
    table = pa.table({"timestamp": [first_ts, second_ts], "value": [1, 2]})
    source = ArrowTableDataSource(table, itemgetter("timestamp"))

    # Nothing read yet: the first row is pending.
    assert source.get_next() == first_ts
    assert source.read_to(first_ts) == table[:1]
    # Reading to the same timestamp again yields no rows.
    assert source.read_to(first_ts) == table[:0]

    assert source.get_next() == second_ts
    assert source.read_to(second_ts) == table[1:]

    # Source exhausted.
    assert source.get_next() == UTC_MAX
    assert source.read_to(UTC_MAX) == table[:0]


def test_arrow_table_data_source_ooo():
    """Constructing a source from out-of-order timestamps must fail."""
    descending_timestamps = [
        pd.to_datetime("2023-01-02T00:00:00Z"),
        pd.to_datetime("2023-01-01T00:00:00Z"),
    ]
    unsorted_table = pa.table(
        {"timestamp": descending_timestamps, "value": [1, 2]}
    )
    expected_message = "Timestamp column should be monotonic increasing"

    with pytest.raises(AssertionError, match=expected_message):
        ArrowTableDataSource(unsorted_table, itemgetter("timestamp"))

0 comments on commit d8026ec

Please sign in to comment.