diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 15b5cfbce..749910d69 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -1,10 +1,10 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +import logging from dataclasses import InitVar, dataclass, field from datetime import datetime, timedelta -from typing import Any, List, Mapping, MutableMapping, Optional, Union +from typing import Any, List, Mapping, Optional, Union from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean @@ -19,6 +19,8 @@ ) from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse +logger = logging.getLogger("airbyte") + @dataclass class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAuthenticator): @@ -201,7 +203,8 @@ def get_client_secret(self) -> str: self._client_secret.eval(self.config) if self._client_secret else self._client_secret ) if not client_secret: - raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter") + # We've seen some APIs allowing empty client_secret so we will only log here + logger.warning("OAuthAuthenticator was unable to evaluate client_secret parameter hence it'll be empty") return client_secret # type: ignore # value will be returned as a string, or an error will be raised def get_refresh_token_name(self) -> str: diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index ba76b0a39..5d18f1ee9 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -88,6 +88,7 @@ def __init__( 
emit_connector_builder_messages=emit_connector_builder_messages, disable_resumable_full_refresh=True, connector_state_manager=self._connector_state_manager, + max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"), ) super().__init__( diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index e8bca6cc9..201f163fa 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -320,14 +320,13 @@ def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str return polling_response_context def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice: - stream_slice = StreamSlice( - partition={}, - cursor_slice={}, - extra_fields={ + return StreamSlice( + partition=job.job_parameters().partition, + cursor_slice=job.job_parameters().cursor_slice, + extra_fields=dict(job.job_parameters().extra_fields) | { "creation_response": self._get_creation_response_interpolation_context(job), }, ) - return stream_slice def _get_download_targets(self, job: AsyncJob) -> Iterable[str]: if not self.download_target_requester: diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py index 4e175bb28..a363f9edc 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py @@ -11,6 +11,7 @@ ) from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.types import Config, StreamSlice +from airbyte_cdk.utils.mapping_helpers import get_interpolation_context @dataclass 
@@ -52,8 +53,8 @@ def eval_request_inputs( :param next_page_token: The pagination token :return: The request inputs to set on an outgoing HTTP request """ - kwargs = { - "stream_slice": stream_slice, - "next_page_token": next_page_token, - } + kwargs = get_interpolation_context( + stream_slice=stream_slice, + next_page_token=next_page_token, + ) return self._interpolator.eval(self.config, **kwargs) # type: ignore # self._interpolator is always initialized with a value and will not be None diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py index ed0e54c60..dfe8d6460 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py @@ -8,6 +8,7 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.types import Config, StreamSlice, StreamState +from airbyte_cdk.utils.mapping_helpers import get_interpolation_context @dataclass @@ -51,10 +52,10 @@ def eval_request_inputs( :param valid_value_types: A tuple of types that the interpolator should allow :return: The request inputs to set on an outgoing HTTP request """ - kwargs = { - "stream_slice": stream_slice, - "next_page_token": next_page_token, - } + kwargs = get_interpolation_context( + stream_slice=stream_slice, + next_page_token=next_page_token, + ) interpolated_value = self._interpolator.eval( # type: ignore # self._interpolator is always initialized with a value and will not be None self.config, valid_key_types=valid_key_types, diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py 
index 077aa4573..bc616e5d2 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -5,6 +5,7 @@ import base64 import json import logging +from copy import deepcopy from datetime import timedelta, timezone from unittest.mock import Mock @@ -128,6 +129,20 @@ def test_refresh_with_encode_config_params(self): } assert body == expected + def test_client_secret_empty(self): + config_without_client_secret = deepcopy(config) + del config_without_client_secret["client_secret"] + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + config=config_without_client_secret, + parameters={}, + grant_type="client_credentials", + ) + body = oauth.build_refresh_request_body() + assert body["client_secret"] == "" + def test_refresh_with_decode_config_params(self): updated_config_fields = { "client_id": base64.b64encode(config["client_id"].encode("utf-8")).decode(), diff --git a/unit_tests/sources/declarative/parsers/conftest.py b/unit_tests/sources/declarative/parsers/conftest.py index 3f653ebb1..a51fe4b4e 100644 --- a/unit_tests/sources/declarative/parsers/conftest.py +++ b/unit_tests/sources/declarative/parsers/conftest.py @@ -1,7 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - +from pathlib import Path from typing import Any, Dict import pytest @@ -645,10 +645,7 @@ def expected_manifest_with_url_base_linked_definition_normalized() -> Dict[str, @pytest.fixture def manifest_with_linked_definitions_url_base_authenticator_abnormal_schemas() -> Dict[str, Any]: - with open( - "unit_tests/sources/declarative/parsers/resources/abnormal_schemas_manifest.yaml", - "r", - ) as file: + with open(str(Path(__file__).parent / "resources/abnormal_schemas_manifest.yaml"), "r") as file: return dict(yaml.safe_load(file)) diff --git a/unit_tests/sources/declarative/requesters/test_http_job_repository.py b/unit_tests/sources/declarative/requesters/test_http_job_repository.py index 4be3ecb11..5e3d6b199 100644 --- a/unit_tests/sources/declarative/requesters/test_http_job_repository.py +++ b/unit_tests/sources/declarative/requesters/test_http_job_repository.py @@ -2,6 +2,7 @@ import json +from typing import Optional from unittest import TestCase from unittest.mock import Mock @@ -28,6 +29,8 @@ ) from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -45,112 +48,12 @@ a_record_id,a_value """ _A_CURSOR_FOR_PAGINATION = "a-cursor-for-pagination" +_ERROR_HANDLER = DefaultErrorHandler(config=_ANY_CONFIG, parameters={}) class HttpJobRepositoryTest(TestCase): def setUp(self) -> None: - message_repository = Mock() - error_handler = DefaultErrorHandler(config=_ANY_CONFIG, parameters={}) - - self._create_job_requester = HttpRequester( - name="stream : create_job", - 
url_base=_URL_BASE, - path=_EXPORT_PATH, - error_handler=error_handler, - http_method=HttpMethod.POST, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=False, - ) - - self._polling_job_requester = HttpRequester( - name="stream : polling", - url_base=_URL_BASE, - path=_EXPORT_PATH + "/{{creation_response['id']}}", - error_handler=error_handler, - http_method=HttpMethod.GET, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=False, - ) - - self._download_retriever = SimpleRetriever( - requester=HttpRequester( - name="stream : fetch_result", - url_base="", - path="{{download_target}}", - error_handler=error_handler, - http_method=HttpMethod.GET, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=True, - ), - record_selector=RecordSelector( - extractor=ResponseToFileExtractor({}), - record_filter=None, - transformations=[], - schema_normalization=TypeTransformer(TransformConfig.NoTransform), - config=_ANY_CONFIG, - parameters={}, - ), - primary_key=None, - name="any name", - paginator=DefaultPaginator( - decoder=NoopDecoder(), - page_size_option=None, - page_token_option=RequestOption( - field_name="locator", - inject_into=RequestOptionType.request_parameter, - parameters={}, - ), - pagination_strategy=CursorPaginationStrategy( - cursor_value="{{ headers['Sforce-Locator'] }}", - decoder=NoopDecoder(), - config=_ANY_CONFIG, - parameters={}, - ), - url_base=_URL_BASE, - config=_ANY_CONFIG, - parameters={}, - ), - config=_ANY_CONFIG, - parameters={}, - ) - - self._repository = AsyncHttpJobRepository( - creation_requester=self._create_job_requester, - polling_requester=self._polling_job_requester, - download_retriever=self._download_retriever, - abort_requester=None, - delete_requester=None, - 
status_extractor=DpathExtractor( - decoder=JsonDecoder(parameters={}), - field_path=["status"], - config={}, - parameters={} or {}, - ), - status_mapping={ - "ready": AsyncJobStatus.COMPLETED, - "failure": AsyncJobStatus.FAILED, - "pending": AsyncJobStatus.RUNNING, - }, - download_target_extractor=DpathExtractor( - decoder=JsonDecoder(parameters={}), - field_path=["urls"], - config={}, - parameters={} or {}, - ), - ) - + self._repository = self._create_async_job_repository() self._http_mocker = HttpMocker() self._http_mocker.__enter__() @@ -178,6 +82,32 @@ def test_given_different_statuses_when_update_jobs_status_then_update_status_pro self._repository.update_jobs_status([job]) assert job.status() == AsyncJobStatus.COMPLETED + def test_when_update_jobs_status_then_allow_access_to_stream_slice_information(self) -> None: + stream_slice = StreamSlice(partition={"path": "path_from_slice"}, cursor_slice={}) + self._mock_create_response(_A_JOB_ID) + self._http_mocker.get( + HttpRequest(url=f"{_EXPORT_URL}/{stream_slice['path']}/{_A_JOB_ID}"), + HttpResponse(body=json.dumps({"id": _A_JOB_ID, "status": "ready"})), + ) + repository = self._create_async_job_repository(HttpRequester( + name="stream : polling", + url_base=_URL_BASE, + path=_EXPORT_PATH + "/{{stream_slice['path']}}/{{creation_response['id']}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=Mock(), # this might not align with the rest of the components in async job repository but if message_repository becomes important for tests, please share this instance with the other components + use_cache=False, + stream_response=False, + )) + + job = repository.start(stream_slice) + repository.update_jobs_status([job]) + + assert job.status() == AsyncJobStatus.COMPLETED + def test_given_unknown_status_when_update_jobs_status_then_raise_error(self) -> None: self._mock_create_response(_A_JOB_ID) self._http_mocker.get( @@ 
-277,3 +207,103 @@ def _mock_create_response(self, job_id: str) -> None: HttpRequest(url=_EXPORT_URL), HttpResponse(body=json.dumps({"id": job_id})), ) + + def _create_async_job_repository(self, polling_job_requester: Optional[HttpRequester] = None) -> AsyncHttpJobRepository: + message_repository = Mock() + create_job_requester = HttpRequester( + name="stream : create_job", + url_base=_URL_BASE, + path=_EXPORT_PATH, + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.POST, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=False, + ) + polling_job_requester = polling_job_requester if polling_job_requester else HttpRequester( + name="stream : polling", + url_base=_URL_BASE, + path=_EXPORT_PATH + "/{{creation_response['id']}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=False, + ) + + download_retriever = SimpleRetriever( + requester=HttpRequester( + name="stream : fetch_result", + url_base="", + path="{{download_target}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=True, + ), + record_selector=RecordSelector( + extractor=ResponseToFileExtractor({}), + record_filter=None, + transformations=[], + schema_normalization=TypeTransformer(TransformConfig.NoTransform), + config=_ANY_CONFIG, + parameters={}, + ), + primary_key=None, + name="any name", + paginator=DefaultPaginator( + decoder=NoopDecoder(), + page_size_option=None, + page_token_option=RequestOption( + field_name="locator", + inject_into=RequestOptionType.request_parameter, + parameters={}, + ), + pagination_strategy=CursorPaginationStrategy( + cursor_value="{{ headers['Sforce-Locator'] }}", + 
decoder=NoopDecoder(), + config=_ANY_CONFIG, + parameters={}, + ), + url_base=_URL_BASE, + config=_ANY_CONFIG, + parameters={}, + ), + config=_ANY_CONFIG, + parameters={}, + ) + + return AsyncHttpJobRepository( + creation_requester=create_job_requester, + polling_requester=polling_job_requester, + download_retriever=download_retriever, + abort_requester=None, + delete_requester=None, + status_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["status"], + config={}, + parameters={}, + ), + status_mapping={ + "ready": AsyncJobStatus.COMPLETED, + "failure": AsyncJobStatus.FAILED, + "pending": AsyncJobStatus.RUNNING, + }, + download_target_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["urls"], + config={}, + parameters={}, + ), + )