
Commit

chore(refactor): refactor partition generator to take any stream slicer (#39)

Co-authored-by: Aaron ("AJ") Steers <[email protected]>
Co-authored-by: octavia-squidington-iii <[email protected]>
Co-authored-by: Brian Lai <[email protected]>
4 people authored Nov 14, 2024
1 parent e808271 commit e27cb81
Showing 13 changed files with 552 additions and 295 deletions.
80 changes: 52 additions & 28 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -3,7 +3,7 @@
#

import logging
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union, Callable

from airbyte_cdk.models import (
AirbyteCatalog,
@@ -27,18 +27,24 @@
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
DatetimeBasedCursor as DatetimeBasedCursorModel,
DeclarativeStream as DeclarativeStreamModel,
)
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
ModelToComponentFactory,
ComponentDefinition,
)
from airbyte_cdk.sources.declarative.requesters import HttpRequester
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, Retriever
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
DeclarativePartitionFactory,
StreamSlicerPartitionGenerator,
)
from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
from airbyte_cdk.sources.declarative.types import ConnectionDefinition
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.types import Config, StreamState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AlwaysAvailableAvailabilityStrategy,
)
@@ -213,31 +219,18 @@ def _group_streams(
)
)

# This is an optimization so that we don't invoke any cursor or state management flows within the
# low-code framework because state management is handled through the ConcurrentCursor.
if (
declarative_stream
and declarative_stream.retriever
and isinstance(declarative_stream.retriever, SimpleRetriever)
):
# Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
# called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
# for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
# ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
# with state.
if declarative_stream.retriever.cursor:
declarative_stream.retriever.cursor.set_initial_state(
stream_state=stream_state
)
declarative_stream.retriever.cursor = None

partition_generator = CursorPartitionGenerator(
stream=declarative_stream,
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
cursor=cursor,
connector_state_converter=connector_state_converter,
cursor_field=[cursor.cursor_field.cursor_field_key],
slice_boundary_fields=cursor.slice_boundary_fields,
partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
stream_state,
),
self.message_repository,
),
cursor,
)

concurrent_streams.append(
@@ -350,3 +343,34 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

    def _retriever_factory(
        self, stream_config: ComponentDefinition, source_config: Config, stream_state: StreamState
    ) -> Callable[[], Retriever]:
        def _factory_method() -> Retriever:
            declarative_stream: DeclarativeStream = self._constructor.create_component(
                DeclarativeStreamModel,
                stream_config,
                source_config,
                emit_connector_builder_messages=self._emit_connector_builder_messages,
            )

            # This is an optimization so that we don't invoke any cursor or state management flows within the
            # low-code framework because state management is handled through the ConcurrentCursor.
            if (
                declarative_stream
                and declarative_stream.retriever
                and isinstance(declarative_stream.retriever, SimpleRetriever)
            ):
                # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
                # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
                # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
                # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
                # with state.
                if declarative_stream.retriever.cursor:
                    declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state)
                declarative_stream.retriever.cursor = None

            return declarative_stream.retriever

        return _factory_method
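
For illustration, a minimal sketch (not part of the commit) of what the factory method above buys us: every call returns a brand-new retriever, so no paginator or cursor state is shared between partitions. Here `source`, `stream_config`, `config`, and `stream_state` are assumed names; only `_retriever_factory` comes from this diff.

# Hypothetical usage sketch: each partition gets its own retriever so that
# stateful components are never shared across threads.
retriever_factory = source._retriever_factory(stream_config, config, stream_state)

retriever_a = retriever_factory()  # fresh DeclarativeStream retriever, cursor stripped
retriever_b = retriever_factory()  # independent instance for another partition
assert retriever_a is not retriever_b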
@@ -8,7 +8,7 @@
import re
from copy import deepcopy
from importlib import metadata
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple

import yaml
from airbyte_cdk.models import (
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
return self._source_config

@property
def message_repository(self) -> Union[None, MessageRepository]:
def message_repository(self) -> MessageRepository:
return self._message_repository

@property
85 changes: 85 additions & 0 deletions airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from typing import Iterable, Optional, Mapping, Any, Callable

from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils.slice_hasher import SliceHasher


class DeclarativePartitionFactory:
    def __init__(
        self,
        stream_name: str,
        json_schema: Mapping[str, Any],
        retriever_factory: Callable[[], Retriever],
        message_repository: MessageRepository,
    ) -> None:
        """
        The DeclarativePartitionFactory takes a retriever_factory rather than a retriever directly because our
        components are not thread-safe: classes like `DefaultPaginator` hold mutable state that concurrent
        threads would share and could corrupt. To avoid this, we create one retriever per partition, which
        keeps the processing thread-safe.
        """
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever_factory = retriever_factory
        self._message_repository = message_repository

    def create(self, stream_slice: StreamSlice) -> Partition:
        return DeclarativePartition(
            self._stream_name,
            self._json_schema,
            self._retriever_factory(),
            self._message_repository,
            stream_slice,
        )


class DeclarativePartition(Partition):
    def __init__(
        self,
        stream_name: str,
        json_schema: Mapping[str, Any],
        retriever: Retriever,
        message_repository: MessageRepository,
        stream_slice: StreamSlice,
    ):
        self._stream_name = stream_name
        self._json_schema = json_schema
        self._retriever = retriever
        self._message_repository = message_repository
        self._stream_slice = stream_slice
        self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

    def read(self) -> Iterable[Record]:
        for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
            if isinstance(stream_data, Mapping):
                yield Record(stream_data, self)
            else:
                self._message_repository.emit_message(stream_data)

    def to_slice(self) -> Optional[Mapping[str, Any]]:
        return self._stream_slice

    def stream_name(self) -> str:
        return self._stream_name

    def __hash__(self) -> int:
        return self._hash


class StreamSlicerPartitionGenerator(PartitionGenerator):
    def __init__(
        self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer
    ) -> None:
        self._partition_factory = partition_factory
        self._stream_slicer = stream_slicer

    def generate(self) -> Iterable[Partition]:
        for stream_slice in self._stream_slicer.stream_slices():
            yield self._partition_factory.create(stream_slice)
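
A hedged wiring sketch for the new module above: the stub slicer, `build_retriever`, and `message_repository` are invented for illustration, while the class names and call signatures come from the code in this commit.

from typing import Iterable

from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice


class TwoWindowSlicer(StreamSlicer):
    """Stand-in concurrent slicer yielding two fixed date windows."""

    def stream_slices(self) -> Iterable[StreamSlice]:
        yield StreamSlice(partition={}, cursor_slice={"start": "2024-01-01", "end": "2024-01-31"})
        yield StreamSlice(partition={}, cursor_slice={"start": "2024-02-01", "end": "2024-02-29"})


partition_generator = StreamSlicerPartitionGenerator(
    DeclarativePartitionFactory(
        "users",                    # stream name
        {"type": "object"},         # JSON schema, trivial for the sketch
        lambda: build_retriever(),  # hypothetical helper; one fresh retriever per partition
        message_repository,         # assumed to exist in the surrounding code
    ),
    TwoWindowSlicer(),
)

for partition in partition_generator.generate():
    print(partition.stream_name(), partition.to_slice())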
19 changes: 6 additions & 13 deletions airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
@@ -2,18 +2,17 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable
from abc import ABC

from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
RequestOptionsProvider,
)
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import (
StreamSlicer as ConcurrentStreamSlicer,
)


@dataclass
class StreamSlicer(RequestOptionsProvider):
class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
"""
Slices the stream into a subset of records.
Slices enable state checkpointing and data retrieval parallelization.
@@ -23,10 +22,4 @@ class StreamSlicer(RequestOptionsProvider):
See the stream slicing section of the docs for more information.
"""

@abstractmethod
def stream_slices(self) -> Iterable[StreamSlice]:
"""
Defines stream slices
:return: List of stream slices
"""
pass
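
Because the declarative `StreamSlicer` now inherits `stream_slices()` from the concurrent interface, a single class can serve both the request-option and partition-generation roles. A sketch under assumptions: the `RequestOptionsProvider` hook names below reflect the CDK's interface as of this commit, with signatures simplified to keyword arguments.

from typing import Any, Iterable, Mapping, Optional

from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice


class MonthlySlicer(StreamSlicer):
    """Toy slicer: one slice per month; request-option hooks stay inert."""

    def stream_slices(self) -> Iterable[StreamSlice]:
        for month in ("2024-01", "2024-02", "2024-03"):
            yield StreamSlice(partition={}, cursor_slice={"month": month})

    # RequestOptionsProvider hooks (signatures simplified for the sketch).
    def get_request_params(self, **kwargs: Any) -> Mapping[str, Any]:
        return {}

    def get_request_headers(self, **kwargs: Any) -> Mapping[str, Any]:
        return {}

    def get_request_body_data(self, **kwargs: Any) -> Mapping[str, Any]:
        return {}

    def get_request_body_json(self, **kwargs: Any) -> Optional[Mapping[str, Any]]:
        return None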
91 changes: 4 additions & 87 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
@@ -38,15 +38,13 @@
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
DateTimeStreamStateConverter,
)
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from deprecated.classic import deprecated

from airbyte_cdk.utils.slice_hasher import SliceHasher

"""
This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream
"""
@@ -270,6 +268,7 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._hash = SliceHasher.hash(self._stream.name, self._slice)

def read(self) -> Iterable[Record]:
"""
@@ -309,12 +308,7 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
return self._slice

def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder)
return hash((self._stream.name, s))
else:
return hash(self._stream.name)
return self._hash

def stream_name(self) -> str:
return self._stream.name
@@ -363,83 +357,6 @@ def generate(self) -> Iterable[Partition]:
)


class CursorPartitionGenerator(PartitionGenerator):
"""
This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions.
It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained
across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state.
"""

_START_BOUNDARY = 0
_END_BOUNDARY = 1

def __init__(
self,
stream: Stream,
message_repository: MessageRepository,
cursor: Cursor,
connector_state_converter: DateTimeStreamStateConverter,
cursor_field: Optional[List[str]],
slice_boundary_fields: Optional[Tuple[str, str]],
):
"""
Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor.
:param stream: The stream to delegate to for partition generation.
:param message_repository: The message repository to use to emit non-record messages.
:param sync_mode: The synchronization mode.
:param cursor: A Cursor object that maintains the state and the cursor field.
"""
self._stream = stream
self.message_repository = message_repository
self._sync_mode = SyncMode.full_refresh
self._cursor = cursor
self._cursor_field = cursor_field
self._state = self._cursor.state
self._slice_boundary_fields = slice_boundary_fields
self._connector_state_converter = connector_state_converter

def generate(self) -> Iterable[Partition]:
"""
Generate partitions based on the slices in the cursor's state.
This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
a `StreamPartition` object.
:return: An iterable of StreamPartition objects.
"""

start_boundary = (
self._slice_boundary_fields[self._START_BOUNDARY]
if self._slice_boundary_fields
else "start"
)
end_boundary = (
self._slice_boundary_fields[self._END_BOUNDARY]
if self._slice_boundary_fields
else "end"
)

for slice_start, slice_end in self._cursor.generate_slices():
stream_slice = StreamSlice(
partition={},
cursor_slice={
start_boundary: self._connector_state_converter.output_format(slice_start),
end_boundary: self._connector_state_converter.output_format(slice_end),
},
)

yield StreamPartition(
self._stream,
copy.deepcopy(stream_slice),
self.message_repository,
self._sync_mode,
self._cursor_field,
self._state,
)


@deprecated(
"Availability strategy has been soft deprecated. Do not use. Class is subject to removal",
category=ExperimentalClassWarning,
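
For context on the `__hash__` change in adapters.py above: the removed body serialized the slice on every call, while the new code precomputes a hash once via the shared `SliceHasher` at construction time. A rough sketch of the removed behavior (the `SliceEncoder` argument is omitted; `SliceHasher` internals are not shown in this diff):

import json
from typing import Any, Mapping, Optional


def legacy_partition_hash(stream_name: str, stream_slice: Optional[Mapping[str, Any]]) -> int:
    # Serialize deterministically so equal slices produce equal hashes.
    if stream_slice:
        serialized = json.dumps(stream_slice, sort_keys=True)
        return hash((stream_name, serialized))
    return hash(stream_name)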