chore(refactor): refactor partition generator to take any stream slicer #39

Merged: 17 commits, Nov 14, 2024
Changes from 2 commits
@@ -114,7 +114,8 @@ def on_partition_complete_sentinel(

         try:
             if sentinel.is_successful:
-                partition.close()
+                stream = self._stream_name_to_instance[partition.stream_name()]
+                stream.cursor.close_partition(partition)
         except Exception as exception:
             self._flag_exception(partition.stream_name(), exception)
             yield AirbyteTracedException.from_exception(
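
For orientation, the new close path can be sketched with stand-in types (the _Cursor, _Stream, _Partition, and on_success names below are illustrative, not the CDK's actual classes): the processor now resolves the owning stream by name and delegates closing to that stream's cursor, rather than asking the partition to close itself.

# Illustrative sketch of the delegation in the hunk above, with stand-in types.
from typing import Dict


class _Partition:
    def __init__(self, stream_name: str) -> None:
        self._stream_name = stream_name

    def stream_name(self) -> str:
        return self._stream_name


class _Cursor:
    def close_partition(self, partition: _Partition) -> None:
        # a real cursor would checkpoint state for the finished slice here
        print(f"closed a partition of {partition.stream_name()}")


class _Stream:
    def __init__(self, cursor: _Cursor) -> None:
        self.cursor = cursor


def on_success(streams: Dict[str, _Stream], partition: _Partition) -> None:
    # same shape as the hunk above: look up the stream, let its cursor close
    streams[partition.stream_name()].cursor.close_partition(partition)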
21 changes: 13 additions & 8 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -30,12 +30,15 @@
 )
 from airbyte_cdk.sources.declarative.requesters import HttpRequester
 from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
+from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
+    DeclarativePartitionFactory,
+    StreamSlicerPartitionGenerator,
+)
 from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
 from airbyte_cdk.sources.declarative.types import ConnectionDefinition
 from airbyte_cdk.sources.source import TState
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
-from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
 from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
     AlwaysAvailableAvailabilityStrategy,
 )
@@ -228,13 +231,15 @@ def _group_streams(
                 )
                 declarative_stream.retriever.cursor = None

-                partition_generator = CursorPartitionGenerator(
-                    stream=declarative_stream,
-                    message_repository=self.message_repository,  # type: ignore  # message_repository is always instantiated with a value by factory
-                    cursor=cursor,
-                    connector_state_converter=connector_state_converter,
-                    cursor_field=[cursor.cursor_field.cursor_field_key],
-                    slice_boundary_fields=cursor.slice_boundary_fields,
+                partition_generator = StreamSlicerPartitionGenerator(
+                    DeclarativePartitionFactory(
+                        declarative_stream.name,
+                        declarative_stream.get_json_schema(),
+                        declarative_stream.retriever,
+                        self.message_repository,
+                    ),
+                    cursor,
                 )

                 concurrent_streams.append(
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
         return self._source_config

     @property
-    def message_repository(self) -> Union[None, MessageRepository]:
+    def message_repository(self) -> MessageRepository:
         return self._message_repository

     @property
66 changes: 66 additions & 0 deletions airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from typing import Iterable, Optional, Mapping, Any

from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils.slice_hasher import SliceHasher


class DeclarativePartitionFactory:
    def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository) -> None:
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever = retriever  # FIXME: it should be a retriever_factory here to ensure that paginators and other classes don't share internal/class state
        self._message_repository = message_repository

    def create(self, stream_slice: StreamSlice) -> Partition:
        return DeclarativePartition(
            self._stream_name,
            self._json_schema,
            self._retriever,
            self._message_repository,
            stream_slice,
        )


class DeclarativePartition(Partition):
    def __init__(self, stream_name: str, json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository, stream_slice: StreamSlice):
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever = retriever
        self._message_repository = message_repository
        self._stream_slice = stream_slice
        self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

    def read(self) -> Iterable[Record]:
        for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
            if isinstance(stream_data, Mapping):
                # TODO validate if this is necessary: self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
                yield Record(stream_data, self)
            else:
                self._message_repository.emit_message(stream_data)

    def to_slice(self) -> Optional[Mapping[str, Any]]:
        return self._stream_slice

    def stream_name(self) -> str:
        return self._stream_name

    def __hash__(self) -> int:
        return self._hash


class StreamSlicerPartitionGenerator(PartitionGenerator):
    def __init__(self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer) -> None:
        self._partition_factory = partition_factory
        self._stream_slicer = stream_slicer

    def generate(self) -> Iterable[Partition]:
        for stream_slice in self._stream_slicer.stream_slices():
            yield self._partition_factory.create(stream_slice)
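
Taken together, the pieces in this new module compose as follows. A minimal usage sketch, where schema, my_retriever, message_repo, and my_slicer are placeholders for objects built elsewhere (only the class names and call signatures come from this file):

factory = DeclarativePartitionFactory(
    "my_stream",    # stream name, used for hashing and stream_name()
    schema,         # the stream's JSON schema (placeholder)
    my_retriever,   # a declarative Retriever (placeholder)
    message_repo,   # a MessageRepository for non-record messages (placeholder)
)
partition_generator = StreamSlicerPartitionGenerator(factory, my_slicer)

for partition in partition_generator.generate():  # one partition per stream slice
    for record in partition.read():               # records come from the retriever
        ...

Because StreamSlicerPartitionGenerator only ever calls stream_slices() on its second argument, anything satisfying the concurrent StreamSlicer interface can drive partitioning, which is what lets concurrent_declarative_source.py above pass the cursor in directly.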
18 changes: 4 additions & 14 deletions airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
@@ -2,18 +2,15 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #

-from abc import abstractmethod
-from dataclasses import dataclass
-from typing import Iterable
+from abc import ABC

 from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
     RequestOptionsProvider,
 )
-from airbyte_cdk.sources.types import StreamSlice
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer as ConcurrentStreamSlicer


-@dataclass
-class StreamSlicer(RequestOptionsProvider):
+class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
     """
     Slices the stream into a subset of records.
     Slices enable state checkpointing and data retrieval parallelization.
@@ -22,11 +19,4 @@ class StreamSlicer(RequestOptionsProvider):

     See the stream slicing section of the docs for more information.
     """
-
-    @abstractmethod
-    def stream_slices(self) -> Iterable[StreamSlice]:
-        """
-        Defines stream slices
-
-        :return: List of stream slices
-        """
-        pass
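
Since stream_slices() is now inherited from ConcurrentStreamSlicer rather than re-declared here, concrete declarative slicers keep the same shape as before. A minimal sketch of a subclass under the merged interface (the DepartmentSlicer class is hypothetical; it assumes the Iterable and StreamSlice imports that this diff removes from the module, and it stubs the four RequestOptionsProvider methods):

from typing import Any, Iterable, Mapping

from airbyte_cdk.sources.types import StreamSlice


class DepartmentSlicer(StreamSlicer):
    """Hypothetical slicer: yields one slice per configured department."""

    def __init__(self, departments: Iterable[str]) -> None:
        self._departments = list(departments)

    def stream_slices(self) -> Iterable[StreamSlice]:
        for department in self._departments:
            yield StreamSlice(partition={"department": department}, cursor_slice={})

    # RequestOptionsProvider still requires these; stubbed here for brevity.
    def get_request_params(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}

    def get_request_headers(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}

    def get_request_body_data(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}

    def get_request_body_json(self, *, stream_state=None, stream_slice=None, next_page_token=None) -> Mapping[str, Any]:
        return {}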
11 changes: 0 additions & 11 deletions airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
@@ -226,16 +226,13 @@ def __init__(
         sync_mode: SyncMode,
         cursor_field: Optional[List[str]],
         state: Optional[MutableMapping[str, Any]],
-        cursor: "AbstractConcurrentFileBasedCursor",
     ):
         self._stream = stream
         self._slice = _slice
         self._message_repository = message_repository
         self._sync_mode = sync_mode
         self._cursor_field = cursor_field
         self._state = state
-        self._cursor = cursor
-        self._is_closed = False

     def read(self) -> Iterable[Record]:
         try:
@@ -289,13 +286,6 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
         file = self._slice["files"][0]
         return {"files": [file]}

-    def close(self) -> None:
-        self._cursor.close_partition(self)
-        self._is_closed = True
-
-    def is_closed(self) -> bool:
-        return self._is_closed
-
     def __hash__(self) -> int:
         if self._slice:
             # Convert the slice to a string so that it can be hashed
@@ -352,7 +342,6 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
                     self._sync_mode,
                     self._cursor_field,
                     self._state,
-                    self._cursor,
                 )
             )
         self._cursor.set_pending_partitions(pending_partitions)
102 changes: 4 additions & 98 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
@@ -47,6 +47,8 @@
 from airbyte_cdk.sources.utils.slice_logger import SliceLogger
 from deprecated.classic import deprecated

+from airbyte_cdk.utils.slice_hasher import SliceHasher
+
 """
 This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream
 """
@@ -96,7 +98,6 @@ def create_from_stream(
                 else SyncMode.incremental,
                 [cursor_field] if cursor_field is not None else None,
                 state,
-                cursor,
             ),
             name=stream.name,
             namespace=stream.namespace,
@@ -259,7 +260,6 @@ def __init__(
         sync_mode: SyncMode,
         cursor_field: Optional[List[str]],
         state: Optional[MutableMapping[str, Any]],
-        cursor: Cursor,
     ):
         """
         :param stream: The stream to delegate to
@@ -272,8 +272,7 @@
         self._sync_mode = sync_mode
         self._cursor_field = cursor_field
         self._state = state
-        self._cursor = cursor
-        self._is_closed = False
+        self._hash = SliceHasher.hash(self._stream.name, self._slice)

     def read(self) -> Iterable[Record]:
         """
@@ -313,23 +312,11 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._slice

     def __hash__(self) -> int:
-        if self._slice:
-            # Convert the slice to a string so that it can be hashed
-            s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder)
-            return hash((self._stream.name, s))
-        else:
-            return hash(self._stream.name)
+        return self._hash

     def stream_name(self) -> str:
         return self._stream.name

-    def close(self) -> None:
-        self._cursor.close_partition(self)
-        self._is_closed = True
-
-    def is_closed(self) -> bool:
-        return self._is_closed
-
     def __repr__(self) -> str:
         return f"StreamPartition({self._stream.name}, {self._slice})"

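The removed __hash__ body above is exactly what SliceHasher now centralizes: hash the stream name alone when there is no slice, otherwise hash the name together with a canonical serialization of the slice. A rough sketch of that contract (illustrative reimplementation, not the real class in airbyte_cdk/utils/slice_hasher.py, which presumably also handles non-JSON-serializable values the way SliceEncoder did above):

import json
from typing import Any, Mapping, Optional


class SliceHasherSketch:
    @classmethod
    def hash(cls, stream_name: str, stream_slice: Optional[Mapping[str, Any]] = None) -> int:
        if stream_slice:
            # sort_keys gives a canonical serialization, so equal slices hash equally
            serialized = json.dumps(stream_slice, sort_keys=True)
            return hash((stream_name, serialized))
        return hash(stream_name)
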
@@ -349,7 +336,6 @@ def __init__(
         sync_mode: SyncMode,
         cursor_field: Optional[List[str]],
         state: Optional[MutableMapping[str, Any]],
-        cursor: Cursor,
     ):
         """
         :param stream: The stream to delegate to
@@ -360,7 +346,6 @@
         self._sync_mode = sync_mode
         self._cursor_field = cursor_field
         self._state = state
-        self._cursor = cursor

     def generate(self) -> Iterable[Partition]:
         for s in self._stream.stream_slices(
@@ -373,85 +358,6 @@
                 self._sync_mode,
                 self._cursor_field,
                 self._state,
-                self._cursor,
             )


-class CursorPartitionGenerator(PartitionGenerator):
-    """
-    This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions.
-
-    It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained
-    across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state.
-    """
-
-    _START_BOUNDARY = 0
-    _END_BOUNDARY = 1
-
-    def __init__(
-        self,
-        stream: Stream,
-        message_repository: MessageRepository,
-        cursor: Cursor,
-        connector_state_converter: DateTimeStreamStateConverter,
-        cursor_field: Optional[List[str]],
-        slice_boundary_fields: Optional[Tuple[str, str]],
-    ):
-        """
-        Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
-
-        :param stream: The stream to delegate to for partition generation.
-        :param message_repository: The message repository to use to emit non-record messages.
-        :param sync_mode: The synchronization mode.
-        :param cursor: A Cursor object that maintains the state and the cursor field.
-        """
-        self._stream = stream
-        self.message_repository = message_repository
-        self._sync_mode = SyncMode.full_refresh
-        self._cursor = cursor
-        self._cursor_field = cursor_field
-        self._state = self._cursor.state
-        self._slice_boundary_fields = slice_boundary_fields
-        self._connector_state_converter = connector_state_converter
-
-    def generate(self) -> Iterable[Partition]:
-        """
-        Generate partitions based on the slices in the cursor's state.
-
-        This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
-        a `StreamPartition` object.
-
-        :return: An iterable of StreamPartition objects.
-        """
-
-        start_boundary = (
-            self._slice_boundary_fields[self._START_BOUNDARY]
-            if self._slice_boundary_fields
-            else "start"
-        )
-        end_boundary = (
-            self._slice_boundary_fields[self._END_BOUNDARY]
-            if self._slice_boundary_fields
-            else "end"
-        )
-
-        for slice_start, slice_end in self._cursor.generate_slices():
-            stream_slice = StreamSlice(
-                partition={},
-                cursor_slice={
-                    start_boundary: self._connector_state_converter.output_format(slice_start),
-                    end_boundary: self._connector_state_converter.output_format(slice_end),
-                },
-            )
-
-            yield StreamPartition(
-                self._stream,
-                copy.deepcopy(stream_slice),
-                self.message_repository,
-                self._sync_mode,
-                self._cursor_field,
-                self._state,
-                self._cursor,
-            )
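The deleted CursorPartitionGenerator needs no direct replacement: in concurrent_declarative_source.py above, the concurrent cursor itself is handed to StreamSlicerPartitionGenerator as the stream slicer. That only works if the cursor exposes stream_slices(); a rough sketch of the assumed shape (illustrative stand-in, not the CDK's actual ConcurrentCursor):

from typing import Iterable

from airbyte_cdk.sources.types import StreamSlice


class _CursorAsSlicer:
    """Anything exposing stream_slices() can drive StreamSlicerPartitionGenerator."""

    def stream_slices(self) -> Iterable[StreamSlice]:
        # A real concurrent cursor would derive these boundaries from its state,
        # much like the removed generate() above did with generate_slices().
        for start, end in (("2024-01-01", "2024-01-31"), ("2024-02-01", "2024-02-29")):
            yield StreamSlice(partition={}, cursor_slice={"start": start, "end": end})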