Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(airbyte-cdk): replace pydantic BaseModel with dataclasses in protocol #44026

Closed
wants to merge 9 commits into from
Closed
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification:
else:
raise FileNotFoundError("Unable to find spec.yaml or spec.json in the package.")

return ConnectorSpecification.parse_obj(spec_obj)
return ConnectorSpecification(spec_obj)

@abstractmethod
def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionStatus:
Expand Down
4 changes: 2 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from requests import PreparedRequest, Response, Session

from orjson import orjson
logger = init_logger("airbyte")

VALID_URL_SCHEMES = ["https"]
Expand Down Expand Up @@ -200,7 +200,7 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str,

@staticmethod
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> Any:
return airbyte_message.model_dump_json(exclude_unset=True)
return orjson.dumps(airbyte_message).decode()

@classmethod
def extract_state(cls, args: List[str]) -> Optional[Any]:
Expand Down
4 changes: 3 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from orjson import orjson

LOGGING_CONFIG = {
"version": 1,
Expand Down Expand Up @@ -60,7 +61,8 @@ def format(self, record: logging.LogRecord) -> str:
message = super().format(record)
message = filter_secrets(message)
log_message = AirbyteMessage(type="LOG", log=AirbyteLogMessage(level=airbyte_level, message=message))
return log_message.model_dump_json(exclude_unset=True) # type: ignore
return orjson.dumps(log_message).decode()


@staticmethod
def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]:
Expand Down
18 changes: 17 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/models/airbyte_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,20 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from airbyte_protocol.models import *
from airbyte_protocol_dataclasses.models import *


from dataclasses import dataclass, InitVar
from typing import Mapping



@dataclass
class AirbyteStateBlob:
    """
    Free-form state container whose attributes are populated from a mapping.

    The single constructor argument is declared as an ``InitVar``: it is handed to
    ``__post_init__`` and is not stored as a dataclass field. Every key of the mapping
    becomes a regular instance attribute instead.

    :param kwargs: mapping of attribute names to the values to set on the instance
    """

    kwargs: InitVar[Mapping[str, Any]]

    def __post_init__(self, kwargs: Mapping[str, Any]) -> None:
        # Copy every entry of the mapping onto the instance as a plain attribute.
        self.__dict__.update(kwargs)

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ compares the tuple of declared fields. An InitVar
        # is not a field, so that tuple is always empty and any two blobs would compare
        # equal. Compare the dynamically-set attributes instead. @dataclass does not
        # overwrite an __eq__ that is defined explicitly in the class body.
        if not isinstance(other, AirbyteStateBlob):
            return NotImplemented
        return self.__dict__ == other.__dict__


from airbyte_protocol.models import ConfiguredAirbyteCatalog
102 changes: 101 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/models/well_known_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,105 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
#
# from airbyte_protocol.models.well_known_types import *

# generated by datamodel-codegen:
# filename: well_known_types.yaml

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class String:
    """Generated model for the ``String`` entry of well_known_types.yaml.

    Both fields have no default, so the generated ``__init__`` requires them.
    """

    type: str
    description: str


@dataclass
class BinaryData:
    """Generated model for the ``BinaryData`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    description: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str


@dataclass
class Date:
    """Generated model for the ``Date`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str
    description: str


@dataclass
class TimestampWithTimezone:
    """Generated model for the ``TimestampWithTimezone`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str
    description: str


@dataclass
class TimestampWithoutTimezone:
    """Generated model for the ``TimestampWithoutTimezone`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str
    description: str


@dataclass
class TimeWithTimezone:
    """Generated model for the ``TimeWithTimezone`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str
    description: str


@dataclass
class TimeWithoutTimezone:
    """Generated model for the ``TimeWithoutTimezone`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # presumably a regular expression describing valid values — confirm against the yaml
    pattern: str
    description: str


@dataclass
class OneOfItem:
    """Generated model for one alternative inside a ``oneOf`` list.

    Both fields default to ``None``, so an instance may carry either, both, or neither.
    """

    pattern: Optional[str] = None
    enum: Optional[List[str]] = None


@dataclass
class Number:
    """Generated model for the ``Number`` entry of well_known_types.yaml.

    All three fields are required (no defaults are declared).
    """

    type: str
    # List of OneOfItem alternatives, each with an optional pattern and/or enum.
    oneOf: List[OneOfItem]
    description: str


@dataclass
class Integer:
    """Generated model for the ``Integer`` entry of well_known_types.yaml.

    Both fields are required. Unlike ``Number``, no ``description`` field is declared.
    """

    type: str
    # List of OneOfItem alternatives, each with an optional pattern and/or enum.
    oneOf: List[OneOfItem]


@dataclass
class Boolean:
    """Generated model for the ``Boolean`` entry of well_known_types.yaml.

    Both fields have no default, so the generated ``__init__`` requires them.
    """

    type: str
    description: str


@dataclass
class Definitions:
    """Container holding one instance of every well-known type model.

    Each field is named after, and annotated with, the class of the same name. The
    annotations are never evaluated inside the class body because the module uses
    ``from __future__ import annotations``. All fields are required.
    """

    String: String
    BinaryData: BinaryData
    Date: Date
    TimestampWithTimezone: TimestampWithTimezone
    TimestampWithoutTimezone: TimestampWithoutTimezone
    TimeWithTimezone: TimeWithTimezone
    TimeWithoutTimezone: TimeWithoutTimezone
    Number: Number
    Integer: Integer
    Boolean: Boolean


from airbyte_protocol.models.well_known_types import *
@dataclass
class Model:
    """Top-level generated model; wraps the ``Definitions`` container in a single required field."""

    definitions: Definitions
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

from dataclasses import dataclass
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateBlob,
Expand All @@ -19,13 +19,15 @@
from pydantic import ConfigDict as V2ConfigDict


class HashableStreamDescriptor(StreamDescriptor):
@dataclass(frozen=True)
class HashableStreamDescriptor:
"""
Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and
freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.
"""

model_config = V2ConfigDict(extra="allow", frozen=True)
name: str
namespace: Optional[str] = None
# model_config = V2ConfigDict(extra="allow", frozen=True)


class ConnectorStateManager:
Expand Down Expand Up @@ -73,7 +75,7 @@ def update_state_for_stream(self, stream_name: str, namespace: Optional[str], va
:param value: A stream state mapping that is being updated for a stream
"""
stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
self.per_stream_states[stream_descriptor] = AirbyteStateBlob.parse_obj(value)
self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)

def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
"""
Expand Down Expand Up @@ -163,7 +165,7 @@ def _create_descriptor_to_stream_state_mapping(
for stream_name, state_value in state.items():
namespace = stream_to_instance_map[stream_name].namespace if stream_name in stream_to_instance_map else None
stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
streams[stream_descriptor] = AirbyteStateBlob.parse_obj(state_value or {})
streams[stream_descriptor] = AirbyteStateBlob(state_value or {})
return streams

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def read_state(cls, state_path: str) -> Union[List[AirbyteStateMessage], Mutable
if isinstance(state_obj, List):
parsed_state_messages = []
for state in state_obj: # type: ignore # `isinstance(state_obj, List)` ensures that this is a list
parsed_message = AirbyteStateMessage.parse_obj(state)
parsed_message = AirbyteStateMessage(state)
if not parsed_message.stream and not parsed_message.data and not parsed_message.global_:
raise ValueError("AirbyteStateMessage should contain either a stream, global, or state field")
parsed_state_messages.append(parsed_message)
Expand Down Expand Up @@ -92,7 +92,7 @@ def _emit_legacy_state_format(cls, state_obj: Dict[str, Any]) -> Union[List[Airb
# can be overridden to change an input catalog
@classmethod
def read_catalog(cls, catalog_path: str) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog.parse_obj(cls._read_json_file(catalog_path))
return ConfiguredAirbyteCatalog.model_validate(cls._read_json_file(catalog_path))

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/utils/message_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
from airbyte_protocol.models import AirbyteMessage, Type
from airbyte_cdk.models import AirbyteMessage, Type


def get_stream_descriptor(message: AirbyteMessage) -> HashableStreamDescriptor:
Expand Down
8 changes: 4 additions & 4 deletions airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import time
import traceback
from datetime import datetime
from typing import Optional

from airbyte_cdk.models import (
Expand All @@ -18,6 +17,7 @@
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from orjson import orjson


class AirbyteTracedException(Exception):
Expand Down Expand Up @@ -54,7 +54,7 @@ def as_airbyte_message(self, stream_descriptor: Optional[StreamDescriptor] = Non
:param stream_descriptor is deprecated, please use the stream_description in `__init__` or `from_exception`. If many
stream_descriptors are defined, the one from `as_airbyte_message` will be discarded.
"""
now_millis = datetime.now().timestamp() * 1000.0
now_millis = time.time_ns() // 1_000_000

trace_exc = self._exception or self
stack_trace_str = "".join(traceback.TracebackException.from_exception(trace_exc).format())
Expand Down Expand Up @@ -85,7 +85,7 @@ def emit_message(self) -> None:
Prints the exception as an AirbyteTraceMessage.
Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint.
"""
message = self.as_airbyte_message().model_dump_json(exclude_unset=True)
message = orjson.dumps(self.as_airbyte_message()).decode()
filtered_message = filter_secrets(message)
print(filtered_message)

Expand Down
Loading
Loading