Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework record/replay to record at the database connection level. #244

Merged
merged 8 commits into from
Jul 16, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240625-110833.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Rework record/replay to record at the database connection level.
time: 2024-06-25T11:08:33.264457-04:00
custom:
Author: peteralllenwebb
Issue: "244"
67 changes: 0 additions & 67 deletions dbt/adapters/record.py

This file was deleted.

2 changes: 2 additions & 0 deletions dbt/adapters/record/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dbt.adapters.record.handle import RecordReplayHandle
from dbt.adapters.record.cursor.cursor import RecordReplayCursor
54 changes: 54 additions & 0 deletions dbt/adapters/record/cursor/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Optional

from dbt_common.record import record_function

from dbt.adapters.contracts.connection import Connection
from dbt.adapters.record.cursor.description import CursorGetDescriptionRecord
from dbt.adapters.record.cursor.execute import CursorExecuteRecord
from dbt.adapters.record.cursor.fetchone import CursorFetchOneRecord
from dbt.adapters.record.cursor.fetchmany import CursorFetchManyRecord
from dbt.adapters.record.cursor.fetchall import CursorFetchAllRecord
from dbt.adapters.record.cursor.rowcount import CursorGetRowCountRecord


class RecordReplayCursor:
"""A proxy object used to wrap native database cursors under record/replay
modes. In record mode, this proxy notes the parameters and return values
of the methods and properties it implements, which closely match the Python
DB API 2.0 cursor methods used by many dbt adapters to interact with the
database or DWH. In replay mode, it mocks out those calls using previously
recorded calls, so that no interaction with a database actually occurs."""

def __init__(self, native_cursor: Any, connection: Connection) -> None:
self.native_cursor = native_cursor
self.connection = connection

@record_function(CursorExecuteRecord, method=True, id_field_name="connection_name")
def execute(self, operation, parameters=None) -> None:
self.native_cursor.execute(operation, parameters)

@record_function(CursorFetchOneRecord, method=True, id_field_name="connection_name")
def fetchone(self) -> Any:
return self.native_cursor.fetchone()

@record_function(CursorFetchManyRecord, method=True, id_field_name="connection_name")
def fetchmany(self, size: int) -> Any:
return self.native_cursor.fetchmany(size)

@record_function(CursorFetchAllRecord, method=True, id_field_name="connection_name")
def fetchall(self) -> Any:
return self.native_cursor.fetchall()

@property
def connection_name(self) -> Optional[str]:
return self.connection.name

@property
@record_function(CursorGetRowCountRecord, method=True, id_field_name="connection_name")
def rowcount(self) -> int:
return self.native_cursor.rowcount

@property
@record_function(CursorGetDescriptionRecord, method=True, id_field_name="connection_name")
def description(self) -> str:
return self.native_cursor.description
37 changes: 37 additions & 0 deletions dbt/adapters/record/cursor/description.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dataclasses
from typing import Any, Iterable, Mapping

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetDescriptionParams:
connection_name: str


@dataclasses.dataclass
class CursorGetDescriptionResult:
columns: Iterable[Any]

def _to_dict(self) -> Any:
column_dicts = []
for c in self.columns:
# This captures the mandatory column information, but we might need
# more for some adapters.
# See https://peps.python.org/pep-0249/#description
column_dicts.append((c[0], c[1]))

return {"columns": column_dicts}

@classmethod
def _from_dict(cls, dct: Mapping) -> "CursorGetDescriptionResult":
return CursorGetDescriptionResult(columns=dct["columns"])


@Recorder.register_record_type
class CursorGetDescriptionRecord(Record):
"""Implements record/replay support for the cursor.description property."""

params_cls = CursorGetDescriptionParams
result_cls = CursorGetDescriptionResult
group = "Database"
20 changes: 20 additions & 0 deletions dbt/adapters/record/cursor/execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import dataclasses
from typing import Any, Iterable, Union, Mapping

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorExecuteParams:
connection_name: str
operation: str
parameters: Union[Iterable[Any], Mapping[str, Any]]


@Recorder.register_record_type
class CursorExecuteRecord(Record):
"""Implements record/replay support for the cursor.execute() method."""

params_cls = CursorExecuteParams
result_cls = None
group = "Database"
66 changes: 66 additions & 0 deletions dbt/adapters/record/cursor/fetchall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import dataclasses
import datetime
from typing import Any, Dict, List, Mapping

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorFetchAllParams:
connection_name: str


@dataclasses.dataclass
class CursorFetchAllResult:
results: List[Any]

def _to_dict(self) -> Dict[str, Any]:
processed_results = []
for result in self.results:
result = tuple(map(self._process_value, result))
processed_results.append(result)

return {"results": processed_results}

@classmethod
def _from_dict(cls, dct: Mapping) -> "CursorFetchAllResult":
unprocessed_results = []
for result in dct["results"]:
result = tuple(map(cls._unprocess_value, result))
unprocessed_results.append(result)

return CursorFetchAllResult(unprocessed_results)

@classmethod
def _process_value(cls, value: Any) -> Any:
if type(value) is datetime.date:
return {"type": "date", "value": value.isoformat()}
elif type(value) is datetime.datetime:
return {"type": "datetime", "value": value.isoformat()}
else:
return value

@classmethod
def _unprocess_value(cls, value: Any) -> Any:
if type(value) is dict:
value_type = value.get("type")
if value_type == "date":
date_string = value.get("value")
assert isinstance(date_string, str)
return datetime.date.fromisoformat(date_string)
elif value_type == "datetime":
date_string = value.get("value")
assert isinstance(date_string, str)
return datetime.datetime.fromisoformat(date_string)
return value
else:
return value


@Recorder.register_record_type
class CursorFetchAllRecord(Record):
"""Implements record/replay support for the cursor.fetchall() method."""

params_cls = CursorFetchAllParams
result_cls = CursorFetchAllResult
group = "Database"
23 changes: 23 additions & 0 deletions dbt/adapters/record/cursor/fetchmany.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import dataclasses
from typing import Any, List

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorFetchManyParams:
connection_name: str


@dataclasses.dataclass
class CursorFetchManyResult:
results: List[Any]


@Recorder.register_record_type
class CursorFetchManyRecord(Record):
"""Implements record/replay support for the cursor.fetchmany() method."""

params_cls = CursorFetchManyParams
result_cls = CursorFetchManyResult
group = "Database"
23 changes: 23 additions & 0 deletions dbt/adapters/record/cursor/fetchone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import dataclasses
from typing import Any

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorFetchOneParams:
connection_name: str


@dataclasses.dataclass
class CursorFetchOneResult:
result: Any


@Recorder.register_record_type
class CursorFetchOneRecord(Record):
"""Implements record/replay support for the cursor.fetchone() method."""

params_cls = CursorFetchOneParams
result_cls = CursorFetchOneResult
group = "Database"
23 changes: 23 additions & 0 deletions dbt/adapters/record/cursor/rowcount.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import dataclasses
from typing import Optional

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetRowCountParams:
connection_name: str


@dataclasses.dataclass
class CursorGetRowCountResult:
rowcount: Optional[int]


@Recorder.register_record_type
class CursorGetRowCountRecord(Record):
"""Implements record/replay support for the cursor.rowcount property."""

params_cls = CursorGetRowCountParams
result_cls = CursorGetRowCountResult
group = "Database"
24 changes: 24 additions & 0 deletions dbt/adapters/record/handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any

from dbt.adapters.contracts.connection import Connection

from dbt.adapters.record.cursor.cursor import RecordReplayCursor


class RecordReplayHandle:
"""A proxy object used for record/replay modes. What adapters call a
'handle' is typically a native database connection, but should not be
confused with the Connection protocol, which is a dbt-adapters concept.

Currently, the only function of the handle proxy is to provide a record/replay
aware cursor object when cursor() is called."""

def __init__(self, native_handle: Any, connection: Connection) -> None:
self.native_handle = native_handle
self.connection = connection

def cursor(self) -> Any:
# The native handle could be None if we are in replay mode, because no
# actual database access should be performed in that mode.
cursor = None if self.native_handle is None else self.native_handle.cursor()
return RecordReplayCursor(cursor, self.connection)
3 changes: 0 additions & 3 deletions dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtInternalError, NotImplementedError
from dbt_common.record import record_function
from dbt_common.utils import cast_to_str

from dbt.adapters.base import BaseConnectionManager
Expand All @@ -20,7 +19,6 @@
SQLQuery,
SQLQueryStatus,
)
from dbt.adapters.record import QueryRecord

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -143,7 +141,6 @@ def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Tab

return table_from_data_flat(data, column_names)

@record_function(QueryRecord, method=True, tuple_result=True)
def execute(
self,
sql: str,
Expand Down
Loading
Loading