Skip to content

Commit

Permalink
feat(Low-Code Concurrent CDK): Make SimpleRetriever thread-safe so th…
Browse files Browse the repository at this point in the history
…at different partitions can share the same SimpleRetriever (#185)
  • Loading branch information
brianjlai authored Jan 9, 2025
1 parent 6d5ce67 commit 0e7802a
Show file tree
Hide file tree
Showing 22 changed files with 635 additions and 435 deletions.
75 changes: 24 additions & 51 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

import logging
from typing import Any, Callable, Generic, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple

from airbyte_cdk.models import (
AirbyteCatalog,
Expand All @@ -28,15 +28,11 @@
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
DatetimeBasedCursor as DatetimeBasedCursorModel,
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
DeclarativeStream as DeclarativeStreamModel,
)
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
ComponentDefinition,
ModelToComponentFactory,
)
from airbyte_cdk.sources.declarative.requesters import HttpRequester
from airbyte_cdk.sources.declarative.retrievers import Retriever, SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
DeclarativePartitionFactory,
StreamSlicerPartitionGenerator,
Expand All @@ -52,7 +48,6 @@
from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream
from airbyte_cdk.sources.types import Config, StreamState


class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
Expand Down Expand Up @@ -194,10 +189,11 @@ def _group_streams(
# Some low-code sources use a combination of DeclarativeStream and regular Python streams. We can't inspect
# these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible,
# so we need to treat them as synchronous
if (
isinstance(declarative_stream, DeclarativeStream)
and name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
if isinstance(declarative_stream, DeclarativeStream) and (
name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
== "SimpleRetriever"
or name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
== "AsyncRetriever"
):
incremental_sync_component_definition = name_to_stream_mapping[
declarative_stream.name
Expand Down Expand Up @@ -234,15 +230,27 @@ def _group_streams(
stream_state=stream_state,
)

retriever = declarative_stream.retriever

# This is an optimization so that we don't invoke any cursor or state management flows within the
# low-code framework because state management is handled through the ConcurrentCursor.
if declarative_stream and isinstance(retriever, SimpleRetriever):
# Also a temporary hack. In the legacy Stream implementation, as part of the read,
# set_initial_state() is called to instantiate incoming state on the cursor. Although we no
# longer rely on the legacy low-code cursor for concurrent checkpointing, low-code components
# like StopConditionPaginationStrategyDecorator and ClientSideIncrementalRecordFilterDecorator
# still rely on a DatetimeBasedCursor that is properly initialized with state.
if retriever.cursor:
retriever.cursor.set_initial_state(stream_state=stream_state)
# We zero it out here, but since this is a cursor reference, the state is still properly
# instantiated for the other components that reference it
retriever.cursor = None

partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
stream_state,
),
retriever,
self.message_repository,
),
cursor,
Expand Down Expand Up @@ -272,11 +280,7 @@ def _group_streams(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
{},
),
declarative_stream.retriever,
self.message_repository,
),
declarative_stream.retriever.stream_slicer,
Expand Down Expand Up @@ -415,34 +419,3 @@ def _remove_concurrent_streams_from_catalog(
if stream.stream.name not in concurrent_stream_names
]
)

def _retriever_factory(
self, stream_config: ComponentDefinition, source_config: Config, stream_state: StreamState
) -> Callable[[], Retriever]:
def _factory_method() -> Retriever:
declarative_stream: DeclarativeStream = self._constructor.create_component(
DeclarativeStreamModel,
stream_config,
source_config,
emit_connector_builder_messages=self._emit_connector_builder_messages,
)

# This is an optimization so that we don't invoke any cursor or state management flows within the
# low-code framework because state management is handled through the ConcurrentCursor.
if (
declarative_stream
and declarative_stream.retriever
and isinstance(declarative_stream.retriever, SimpleRetriever)
):
# Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is
# called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor
# for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and
# ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized
# with state.
if declarative_stream.retriever.cursor:
declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state)
declarative_stream.retriever.cursor = None

return declarative_stream.retriever

return _factory_method
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,39 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
)
if isinstance(self.url_base, str):
self.url_base = InterpolatedString(string=self.url_base, parameters=parameters)
self._token: Optional[Any] = self.pagination_strategy.initial_token

def get_initial_token(self) -> Optional[Any]:
"""
Return the page token that should be used for the first request of a stream.
WARNING: get_initial_token() should not be used by streams that use RFR and checkpoint
state using page numbers, because paginators are stateless.
"""
return self.pagination_strategy.initial_token

def next_page_token(
self, response: requests.Response, last_page_size: int, last_record: Optional[Record]
self,
response: requests.Response,
last_page_size: int,
last_record: Optional[Record],
last_page_token_value: Optional[Any] = None,
) -> Optional[Mapping[str, Any]]:
self._token = self.pagination_strategy.next_page_token(
response, last_page_size, last_record
next_page_token = self.pagination_strategy.next_page_token(
response=response,
last_page_size=last_page_size,
last_record=last_record,
last_page_token_value=last_page_token_value,
)
if self._token:
return {"next_page_token": self._token}
if next_page_token:
return {"next_page_token": next_page_token}
else:
return None

def path(self) -> Optional[str]:
if (
self._token
and self.page_token_option
and isinstance(self.page_token_option, RequestPath)
):
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
token = next_page_token.get("next_page_token") if next_page_token else None
if token and self.page_token_option and isinstance(self.page_token_option, RequestPath):
# Replace url base to only return the path
return str(self._token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__
return str(token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__
else:
return None

Expand All @@ -143,7 +155,7 @@ def get_request_params(
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter)
return self._get_request_options(RequestOptionType.request_parameter, next_page_token)

def get_request_headers(
self,
Expand All @@ -152,7 +164,7 @@ def get_request_headers(
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, str]:
return self._get_request_options(RequestOptionType.header)
return self._get_request_options(RequestOptionType.header, next_page_token)

def get_request_body_data(
self,
Expand All @@ -161,7 +173,7 @@ def get_request_body_data(
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_data)
return self._get_request_options(RequestOptionType.body_data, next_page_token)

def get_request_body_json(
self,
Expand All @@ -170,25 +182,21 @@ def get_request_body_json(
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json)

def reset(self, reset_value: Optional[Any] = None) -> None:
if reset_value:
self.pagination_strategy.reset(reset_value=reset_value)
else:
self.pagination_strategy.reset()
self._token = self.pagination_strategy.initial_token
return self._get_request_options(RequestOptionType.body_json, next_page_token)

def _get_request_options(self, option_type: RequestOptionType) -> MutableMapping[str, Any]:
def _get_request_options(
self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]]
) -> MutableMapping[str, Any]:
options = {}

token = next_page_token.get("next_page_token") if next_page_token else None
if (
self.page_token_option
and self._token is not None
and token is not None
and isinstance(self.page_token_option, RequestOption)
and self.page_token_option.inject_into == option_type
):
options[self.page_token_option.field_name.eval(config=self.config)] = self._token # type: ignore # field_name is always cast to an interpolated string
options[self.page_token_option.field_name.eval(config=self.config)] = token # type: ignore # field_name is always cast to an interpolated string
if (
self.page_size_option
and self.pagination_strategy.get_page_size()
Expand All @@ -204,6 +212,9 @@ class PaginatorTestReadDecorator(Paginator):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
pages that are queried throughout a read command.
WARNING: Unlike the rest of the low-code framework, this decorator is not currently thread-safe, because it
keeps internal state to track the number of pages counted so that it can exit early during a test read.
"""

_PAGE_COUNT_BEFORE_FIRST_NEXT_CALL = 1
Expand All @@ -217,17 +228,27 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No
self._decorated = decorated
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL

def get_initial_token(self) -> Optional[Any]:
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
return self._decorated.get_initial_token()

def next_page_token(
self, response: requests.Response, last_page_size: int, last_record: Optional[Record]
self,
response: requests.Response,
last_page_size: int,
last_record: Optional[Record],
last_page_token_value: Optional[Any] = None,
) -> Optional[Mapping[str, Any]]:
if self._page_count >= self._maximum_number_of_pages:
return None

self._page_count += 1
return self._decorated.next_page_token(response, last_page_size, last_record)
return self._decorated.next_page_token(
response, last_page_size, last_record, last_page_token_value
)

def path(self) -> Optional[str]:
return self._decorated.path()
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
return self._decorated.path(next_page_token)

def get_request_params(
self,
Expand Down Expand Up @@ -272,7 +293,3 @@ def get_request_body_json(
return self._decorated.get_request_body_json(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)

def reset(self, reset_value: Optional[Any] = None) -> None:
self._decorated.reset()
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class NoPagination(Paginator):

parameters: InitVar[Mapping[str, Any]]

def path(self) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
return None

def get_request_params(
Expand Down Expand Up @@ -58,11 +58,14 @@ def get_request_body_json(
) -> Mapping[str, Any]:
return {}

def get_initial_token(self) -> Optional[Any]:
return None

def next_page_token(
self, response: requests.Response, last_page_size: int, last_record: Optional[Record]
) -> Mapping[str, Any]:
self,
response: requests.Response,
last_page_size: int,
last_record: Optional[Record],
last_page_token_value: Optional[Any],
) -> Optional[Mapping[str, Any]]:
return {}

def reset(self, reset_value: Optional[Any] = None) -> None:
# No state to reset
pass
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,32 @@ class Paginator(ABC, RequestOptionsProvider):
"""

@abstractmethod
def reset(self, reset_value: Optional[Any] = None) -> None:
def get_initial_token(self) -> Optional[Any]:
"""
Reset the pagination's inner state
Get the page token that should be included in the request to get the first page of records
"""

@abstractmethod
def next_page_token(
self, response: requests.Response, last_page_size: int, last_record: Optional[Record]
self,
response: requests.Response,
last_page_size: int,
last_record: Optional[Record],
last_page_token_value: Optional[Any],
) -> Optional[Mapping[str, Any]]:
"""
Returns the next_page_token to use to fetch the next page of records.
:param response: the response to process
:param last_page_size: the number of records read from the response
:param last_record: the last record extracted from the response
:param last_page_token_value: the value of the page token that was used to make the last request
:return: A mapping {"next_page_token": <token>} for the next page from the input response object. Returning None means there are no more pages to read in this response.
"""
pass

@abstractmethod
def path(self) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
"""
Returns the URL path to hit to fetch the next page of records
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class CursorPaginationStrategy(PaginationStrategy):
)

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._initial_cursor = None
if isinstance(self.cursor_value, str):
self._cursor_value = InterpolatedString.create(self.cursor_value, parameters=parameters)
else:
Expand All @@ -57,10 +56,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:

@property
def initial_token(self) -> Optional[Any]:
return self._initial_cursor
"""
CursorPaginationStrategy does not have an initial value because the next cursor is typically included
in the response of the first request. For Resumable Full Refresh streams that checkpoint the page
cursor, the next cursor should be read from the state or stream slice object.
"""
return None

def next_page_token(
self, response: requests.Response, last_page_size: int, last_record: Optional[Record]
self,
response: requests.Response,
last_page_size: int,
last_record: Optional[Record],
last_page_token_value: Optional[Any] = None,
) -> Optional[Any]:
decoded_response = next(self.decoder.decode(response))

Expand All @@ -87,8 +95,5 @@ def next_page_token(
)
return token if token else None

def reset(self, reset_value: Optional[Any] = None) -> None:
self._initial_cursor = reset_value

def get_page_size(self) -> Optional[int]:
return self.page_size
Loading

0 comments on commit 0e7802a

Please sign in to comment.