Commit: [ISSUE #10552] move stream slicer concept in concurrent CDK
maxi297 committed Nov 12, 2024
1 parent 4aaf1e7 commit 35290eb

Showing 12 changed files with 384 additions and 264 deletions.
21 changes: 13 additions & 8 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -30,12 +30,15 @@
)
from airbyte_cdk.sources.declarative.requesters import HttpRequester
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
DeclarativePartitionFactory,
StreamSlicerPartitionGenerator,
)
from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
from airbyte_cdk.sources.declarative.types import ConnectionDefinition
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AlwaysAvailableAvailabilityStrategy,
)
@@ -228,13 +231,15 @@ def _group_streams(
)
declarative_stream.retriever.cursor = None

partition_generator = CursorPartitionGenerator(
stream=declarative_stream,
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
cursor=cursor,
connector_state_converter=connector_state_converter,
cursor_field=[cursor.cursor_field.cursor_field_key],
slice_boundary_fields=cursor.slice_boundary_fields,

partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
declarative_stream.retriever,
self.message_repository,
),
cursor,
)

concurrent_streams.append(
airbyte_cdk/sources/declarative/manifest_declarative_source.py
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
return self._source_config

@property
def message_repository(self) -> Union[None, MessageRepository]:
def message_repository(self) -> MessageRepository:
return self._message_repository

@property
66 changes: 66 additions & 0 deletions airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from typing import Iterable, Optional, Mapping, Any

from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils.slice_hasher import SliceHasher


class DeclarativePartitionFactory:
    def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository) -> None:
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever = retriever  # FIXME: it should be a retriever_factory here to ensure that paginators and other classes don't share internal/class state
        self._message_repository = message_repository

    def create(self, stream_slice: StreamSlice) -> Partition:
        return DeclarativePartition(
            self._stream_name,
            self._json_schema,
            self._retriever,
            self._message_repository,
            stream_slice,
        )


class DeclarativePartition(Partition):
    def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository, stream_slice: StreamSlice):
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever = retriever
        self._message_repository = message_repository
        self._stream_slice = stream_slice
        self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

    def read(self) -> Iterable[Record]:
        for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
            if isinstance(stream_data, Mapping):
                # TODO validate if this is necessary: self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
                yield Record(stream_data, self)
            else:
                self._message_repository.emit_message(stream_data)

    def to_slice(self) -> Optional[Mapping[str, Any]]:
        return self._stream_slice

    def stream_name(self) -> str:
        return self._stream_name

    def __hash__(self) -> int:
        return self._hash


class StreamSlicerPartitionGenerator(PartitionGenerator):
    def __init__(self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer) -> None:
        self._partition_factory = partition_factory
        self._stream_slicer = stream_slicer

    def generate(self) -> Iterable[Partition]:
        for stream_slice in self._stream_slicer.stream_slices():
            yield self._partition_factory.create(stream_slice)
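
For orientation, a minimal usage sketch of the new classes (not part of the commit). The StubRetriever and StubSlicer stand-ins are assumptions for illustration only; in concurrent_declarative_source.py above, the retriever comes from the declarative stream and the slicer is the stream's cursor, which now implements stream_slices itself.

from typing import Any, Iterable, Mapping

from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
    DeclarativePartitionFactory,
    StreamSlicerPartitionGenerator,
)
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.types import StreamSlice


class StubRetriever:
    # Duck-typed stand-in implementing only the method DeclarativePartition.read calls.
    def read_records(
        self, records_schema: Mapping[str, Any], stream_slice: StreamSlice
    ) -> Iterable[Mapping[str, Any]]:
        yield {"id": 1, "from": stream_slice.cursor_slice["start"]}


class StubSlicer:
    # Duck-typed stand-in for the concurrent StreamSlicer interface.
    def stream_slices(self) -> Iterable[StreamSlice]:
        yield StreamSlice(partition={}, cursor_slice={"start": "2024-01-01", "end": "2024-01-02"})


partition_generator = StreamSlicerPartitionGenerator(
    DeclarativePartitionFactory(
        "my_stream",
        {"type": "object"},  # the stream's JSON schema
        StubRetriever(),  # type: ignore[arg-type]  # duck-typed for the sketch
        InMemoryMessageRepository(),
    ),
    StubSlicer(),  # type: ignore[arg-type]  # duck-typed for the sketch
)

for partition in partition_generator.generate():
    print(partition.stream_name(), partition.to_slice())
    for record in partition.read():
        print(record.data)

Keeping the factory separate from the generator means the slicing strategy (a cursor, a partition router, or a plain slicer) can vary independently of how partitions are built.
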
18 changes: 4 additions & 14 deletions airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
@@ -2,18 +2,15 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable
from abc import ABC

from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
RequestOptionsProvider,
)
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer as ConcurrentStreamSlicer


@dataclass
class StreamSlicer(RequestOptionsProvider):
class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
"""
Slices the stream into a subset of records.
Slices enable state checkpointing and data retrieval parallelization.
@@ -22,11 +19,4 @@ class StreamSlicer(RequestOptionsProvider):
See the stream slicing section of the docs for more information.
"""

@abstractmethod
def stream_slices(self) -> Iterable[StreamSlice]:
"""
Defines stream slices
:return: List of stream slices
"""
pass
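
After this change, a declarative StreamSlicer is the concurrent slicing interface plus the request-options interface. Below is a hedged sketch of a concrete subclass; MonthlySlicer is illustrative, and the four request-option methods with their keyword-only stream_state/stream_slice/next_page_token parameters are assumptions based on RequestOptionsProvider as defined in the CDK.

from typing import Any, Iterable, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice


class MonthlySlicer(StreamSlicer):
    # Emits one slice per month and forwards the slice value as a query parameter.
    def stream_slices(self) -> Iterable[StreamSlice]:
        for month in ("2024-01", "2024-02"):
            yield StreamSlice(partition={}, cursor_slice={"month": month})

    def get_request_params(
        self,
        *,
        stream_state: Optional[Mapping[str, Any]] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        return {"month": stream_slice["month"]} if stream_slice else {}

    def get_request_headers(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}

    def get_request_body_data(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Union[Mapping[str, Any], str]:
        return {}

    def get_request_body_json(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}


slicer = MonthlySlicer()
for s in slicer.stream_slices():
    print(s["month"], slicer.get_request_params(stream_slice=s))

Because the same object now satisfies both interfaces, it can drive HTTP request options in the declarative framework and partition generation in the concurrent one.
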
87 changes: 4 additions & 83 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
@@ -47,6 +47,8 @@
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from deprecated.classic import deprecated

from airbyte_cdk.utils.slice_hasher import SliceHasher

"""
This module contains adapters to help enable concurrency on Stream objects without needing to migrate to AbstractStream
"""
@@ -270,6 +272,7 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._hash = SliceHasher.hash(self._stream.name, self._slice)

def read(self) -> Iterable[Record]:
"""
@@ -309,12 +312,7 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
return self._slice

def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder)
return hash((self._stream.name, s))
else:
return hash(self._stream.name)
return self._hash

def stream_name(self) -> str:
return self._stream.name
@@ -363,83 +361,6 @@ def generate(self) -> Iterable[Partition]:
)


class CursorPartitionGenerator(PartitionGenerator):
"""
This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions.
It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained
across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state.
"""

_START_BOUNDARY = 0
_END_BOUNDARY = 1

def __init__(
self,
stream: Stream,
message_repository: MessageRepository,
cursor: Cursor,
connector_state_converter: DateTimeStreamStateConverter,
cursor_field: Optional[List[str]],
slice_boundary_fields: Optional[Tuple[str, str]],
):
"""
Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
:param stream: The stream to delegate to for partition generation.
:param message_repository: The message repository to use to emit non-record messages.
:param sync_mode: The synchronization mode.
:param cursor: A Cursor object that maintains the state and the cursor field.
"""
self._stream = stream
self.message_repository = message_repository
self._sync_mode = SyncMode.full_refresh
self._cursor = cursor
self._cursor_field = cursor_field
self._state = self._cursor.state
self._slice_boundary_fields = slice_boundary_fields
self._connector_state_converter = connector_state_converter

def generate(self) -> Iterable[Partition]:
"""
Generate partitions based on the slices in the cursor's state.
This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
a `StreamPartition` object.
:return: An iterable of StreamPartition objects.
"""

start_boundary = (
self._slice_boundary_fields[self._START_BOUNDARY]
if self._slice_boundary_fields
else "start"
)
end_boundary = (
self._slice_boundary_fields[self._END_BOUNDARY]
if self._slice_boundary_fields
else "end"
)

for slice_start, slice_end in self._cursor.generate_slices():
stream_slice = StreamSlice(
partition={},
cursor_slice={
start_boundary: self._connector_state_converter.output_format(slice_start),
end_boundary: self._connector_state_converter.output_format(slice_end),
},
)

yield StreamPartition(
self._stream,
copy.deepcopy(stream_slice),
self.message_repository,
self._sync_mode,
self._cursor_field,
self._state,
)


@deprecated(
"Availability strategy has been soft deprecated. Do not use. Class is subject to removal",
category=ExperimentalClassWarning,
43 changes: 26 additions & 17 deletions airbyte_cdk/sources/streams/concurrent/cursor.py
@@ -11,9 +11,11 @@
from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import (
AbstractStreamStateConverter,
)
from airbyte_cdk.sources.types import StreamSlice


def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
Expand Down Expand Up @@ -61,7 +63,7 @@ def extract_value(self, record: Record) -> CursorValueType:
return cursor_value # type: ignore # we assume that the value the path points at is a comparable


class Cursor(ABC):
class Cursor(StreamSlicer, ABC):
@property
@abstractmethod
def state(self) -> MutableMapping[str, Any]: ...
@@ -88,12 +90,12 @@ def ensure_at_least_one_state_emitted(self) -> None:
"""
raise NotImplementedError()

def generate_slices(self) -> Iterable[Tuple[Any, Any]]:
def stream_slices(self) -> Iterable[StreamSlice]:
"""
Default placeholder implementation of stream_slices.
Subclasses can override this method to provide actual behavior.
"""
yield from ()
yield StreamSlice(partition={}, cursor_slice={})


class FinalStateCursor(Cursor):
Expand Down Expand Up @@ -184,8 +186,8 @@ def cursor_field(self) -> CursorField:
return self._cursor_field

@property
def slice_boundary_fields(self) -> Optional[Tuple[str, str]]:
return self._slice_boundary_fields
def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]:
return self._slice_boundary_fields if self._slice_boundary_fields else (self._connector_state_converter.START_KEY, self._connector_state_converter.END_KEY)

def _get_concurrent_state(
self, state: MutableMapping[str, Any]
@@ -299,7 +301,7 @@ def ensure_at_least_one_state_emitted(self) -> None:
"""
self._emit_state_message()

def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
def stream_slices(self) -> Iterable[StreamSlice]:
"""
Generating slices based on a few parameters:
* lookback_window: Buffer to remove from END_KEY of the highest slice
@@ -368,7 +370,7 @@ def _calculate_lower_boundary_of_last_slice(

def _split_per_slice_range(
self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool
) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
) -> Iterable[StreamSlice]:
if lower >= upper:
return

@@ -377,10 +379,14 @@ def _split_per_slice_range(

lower = max(lower, self._start) if self._start else lower
if not self._slice_range or self._evaluate_upper_safely(lower, self._slice_range) >= upper:
if self._cursor_granularity and not upper_is_end:
yield lower, upper - self._cursor_granularity
else:
yield lower, upper
start_value, end_value = (lower, upper - self._cursor_granularity) if self._cursor_granularity and not upper_is_end else (lower, upper)
yield StreamSlice(
partition={},
cursor_slice={
self._slice_boundary_fields_wrapper[self._START_BOUNDARY]: self._connector_state_converter.output_format(start_value),
self._slice_boundary_fields_wrapper[self._END_BOUNDARY]: self._connector_state_converter.output_format(end_value)
}
)
else:
stop_processing = False
current_lower_boundary = lower
@@ -389,12 +395,15 @@
self._evaluate_upper_safely(current_lower_boundary, self._slice_range), upper
)
has_reached_upper_boundary = current_upper_boundary >= upper
if self._cursor_granularity and (
not upper_is_end or not has_reached_upper_boundary
):
yield current_lower_boundary, current_upper_boundary - self._cursor_granularity
else:
yield current_lower_boundary, current_upper_boundary

start_value, end_value = (current_lower_boundary, current_upper_boundary - self._cursor_granularity) if self._cursor_granularity and (not upper_is_end or not has_reached_upper_boundary) else (current_lower_boundary, current_upper_boundary)
yield StreamSlice(
partition={},
cursor_slice={
self._slice_boundary_fields_wrapper[self._START_BOUNDARY]: self._connector_state_converter.output_format(start_value),
self._slice_boundary_fields_wrapper[self._END_BOUNDARY]: self._connector_state_converter.output_format(end_value)
}
)
current_lower_boundary = current_upper_boundary
if current_upper_boundary >= upper:
stop_processing = True
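
Net effect of the cursor changes: generate_slices returned raw (start, end) tuples that CursorPartitionGenerator then wrapped, whereas stream_slices now emits ready-to-use StreamSlice objects keyed by the configured slice boundary fields (falling back to the state converter's START_KEY/END_KEY via _slice_boundary_fields_wrapper). A purely illustrative sketch of the resulting slice shape, assuming the default "start"/"end" keys and a converter whose output_format renders ISO-8601 strings:

from airbyte_cdk.sources.types import StreamSlice

# Shape of one slice as now yielded by ConcurrentCursor.stream_slices()
slice_ = StreamSlice(
    partition={},
    cursor_slice={"start": "2024-01-01T00:00:00Z", "end": "2024-01-31T23:59:59Z"},
)
assert slice_.cursor_slice["start"] == "2024-01-01T00:00:00Z"
assert slice_["end"] == "2024-01-31T23:59:59Z"  # StreamSlice also behaves like a Mapping
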
21 changes: 21 additions & 0 deletions airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py
@@ -0,0 +1,21 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from abc import ABC, abstractmethod
from typing import Iterable

from airbyte_cdk.sources.types import StreamSlice


class StreamSlicer(ABC):
    """
    Slices the stream into chunks that can be fetched independently. Slices enable state checkpointing and data retrieval parallelization.
    """

    @abstractmethod
    def stream_slices(self) -> Iterable[StreamSlice]:
        """
        Defines stream slices
        :return: An iterable of stream slices
        """
        pass
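
Any object yielding StreamSlice instances from stream_slices can now drive partition generation in the concurrent framework. A minimal sketch, assuming nothing beyond the abstract class above (StaticDateRangeSlicer is illustrative, not part of the CDK):

from typing import Iterable, List, Tuple

from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice


class StaticDateRangeSlicer(StreamSlicer):
    # Yields a fixed set of date-range slices that can be fetched independently.
    def __init__(self, ranges: List[Tuple[str, str]]) -> None:
        self._ranges = ranges

    def stream_slices(self) -> Iterable[StreamSlice]:
        for start, end in self._ranges:
            yield StreamSlice(partition={}, cursor_slice={"start": start, "end": end})


slicer = StaticDateRangeSlicer([("2024-01-01", "2024-01-31"), ("2024-02-01", "2024-02-29")])
for s in slicer.stream_slices():
    print(s.cursor_slice)
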
(Diffs for the remaining 5 changed files are not shown.)