Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(refactor): refactor partition generator to take any stream slicer #39

Merged
merged 17 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def on_partition_complete_sentinel(

try:
if sentinel.is_successful:
partition.close()
stream = self._stream_name_to_instance[partition.stream_name()]
stream.cursor.close_partition(partition)
except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
Expand Down
80 changes: 52 additions & 28 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

import logging
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union, Callable
maxi297 marked this conversation as resolved.
Show resolved Hide resolved

from airbyte_cdk.models import (
AirbyteCatalog,
Expand All @@ -24,18 +24,24 @@
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
DatetimeBasedCursor as DatetimeBasedCursorModel,
DeclarativeStream as DeclarativeStreamModel,
)
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
ModelToComponentFactory,
ComponentDefinition,
)
from airbyte_cdk.sources.declarative.requesters import HttpRequester
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, Retriever
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
DeclarativePartitionFactory,
StreamSlicerPartitionGenerator,
)
from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields
from airbyte_cdk.sources.declarative.types import ConnectionDefinition
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.types import Config, StreamState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AlwaysAvailableAvailabilityStrategy,
)
Expand Down Expand Up @@ -210,31 +216,18 @@ def _group_streams(
)
)

# This is an optimization so that we don't invoke any cursor or state management flows within the
# low-code framework because state management is handled through the ConcurrentCursor.
if (
declarative_stream
and declarative_stream.retriever
and isinstance(declarative_stream.retriever, SimpleRetriever)
):
# Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
# called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
# for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
# ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
# with state.
if declarative_stream.retriever.cursor:
declarative_stream.retriever.cursor.set_initial_state(
stream_state=stream_state
)
declarative_stream.retriever.cursor = None

partition_generator = CursorPartitionGenerator(
stream=declarative_stream,
message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory
cursor=cursor,
connector_state_converter=connector_state_converter,
cursor_field=[cursor.cursor_field.cursor_field_key],
slice_boundary_fields=cursor.slice_boundary_fields,
partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
stream_state,
),
self.message_repository,
),
cursor,
)

concurrent_streams.append(
Expand Down Expand Up @@ -344,3 +337,34 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

def _retriever_factory(
self, stream_config: ComponentDefinition, source_config: Config, stream_state: StreamState
) -> Callable[[], Retriever]:
def _factory_method() -> Retriever:
declarative_stream: DeclarativeStream = self._constructor.create_component(
maxi297 marked this conversation as resolved.
Show resolved Hide resolved
DeclarativeStreamModel,
stream_config,
source_config,
emit_connector_builder_messages=self._emit_connector_builder_messages,
)

# This is an optimization so that we don't invoke any cursor or state management flows within the
# low-code framework because state management is handled through the ConcurrentCursor.
if (
declarative_stream
and declarative_stream.retriever
and isinstance(declarative_stream.retriever, SimpleRetriever)
):
# Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
# called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
# for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
# ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
# with state.
if declarative_stream.retriever.cursor:
declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state)
declarative_stream.retriever.cursor = None

return declarative_stream.retriever

return _factory_method
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
from copy import deepcopy
from importlib import metadata
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple

import yaml
from airbyte_cdk.models import (
Expand Down Expand Up @@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
return self._source_config

@property
def message_repository(self) -> Union[None, MessageRepository]:
def message_repository(self) -> MessageRepository:
return self._message_repository

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
maxi297 marked this conversation as resolved.
Show resolved Hide resolved

from typing import Iterable, Optional, Mapping, Any, Callable

from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils.slice_hasher import SliceHasher


class DeclarativePartitionFactory:
def __init__(
self,
stream_name: str,
json_schema: Mapping[str, Any],
retriever_factory: Callable[[], Retriever],
message_repository: MessageRepository,
) -> None:
"""
The DeclarativePartitionFactory takes a retriever_factory and not a retriever directly. The reason is that out components are not
thread safe and classes like `DefaultPaginator` would not behave the way we want if multiple threads were to call their methods.
In order to avoid these problems, we will create one retriever per thread which should make the processing thread-safe.
maxi297 marked this conversation as resolved.
Show resolved Hide resolved
"""
self._stream_name = stream_name
self._json_schema = json_schema
self._retriever_factory = retriever_factory
self._message_repository = message_repository

def create(self, stream_slice: StreamSlice) -> Partition:
return DeclarativePartition(
self._stream_name,
self._json_schema,
self._retriever_factory(),
self._message_repository,
stream_slice,
)


class DeclarativePartition(Partition):
def __init__(
self,
stream_name: str,
json_schema: Mapping[str, Any],
retriever: Retriever,
message_repository: MessageRepository,
stream_slice: StreamSlice,
):
self._stream_name = stream_name
self._json_schema = json_schema
self._retriever = retriever
self._message_repository = message_repository
maxi297 marked this conversation as resolved.
Show resolved Hide resolved
self._stream_slice = stream_slice
self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

def read(self) -> Iterable[Record]:
for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
if isinstance(stream_data, Mapping):
yield Record(stream_data, self)
else:
self._message_repository.emit_message(stream_data)

def to_slice(self) -> Optional[Mapping[str, Any]]:
return self._stream_slice

def stream_name(self) -> str:
return self._stream_name

def __hash__(self) -> int:
return self._hash


class StreamSlicerPartitionGenerator(PartitionGenerator):
def __init__(
self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer
) -> None:
self._partition_factory = partition_factory
self._stream_slicer = stream_slicer

def generate(self) -> Iterable[Partition]:
for stream_slice in self._stream_slicer.stream_slices():
yield self._partition_factory.create(stream_slice)
19 changes: 6 additions & 13 deletions airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable
from abc import ABC

from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
RequestOptionsProvider,
)
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import (
StreamSlicer as ConcurrentStreamSlicer,
)


@dataclass
class StreamSlicer(RequestOptionsProvider):
class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC):
"""
Slices the stream into a subset of records.
Slices enable state checkpointing and data retrieval parallelization.
Expand All @@ -23,10 +22,4 @@ class StreamSlicer(RequestOptionsProvider):
See the stream slicing section of the docs for more information.
"""

@abstractmethod
def stream_slices(self) -> Iterable[StreamSlice]:
"""
Defines stream slices

:return: List of stream slices
"""
pass
11 changes: 0 additions & 11 deletions airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,13 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
self._slice = _slice
self._message_repository = message_repository
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False

def read(self) -> Iterable[Record]:
try:
Expand Down Expand Up @@ -289,13 +286,6 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
file = self._slice["files"][0]
return {"files": [file]}

def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True

def is_closed(self) -> bool:
return self._is_closed

def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
Expand Down Expand Up @@ -352,7 +342,6 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)
)
self._cursor.set_pending_partitions(pending_partitions)
Expand Down
Loading
Loading