diff --git a/.changes/unreleased/Under the Hood-20231205-165812.yaml b/.changes/unreleased/Under the Hood-20231205-165812.yaml new file mode 100644 index 00000000000..8dcf402535c --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231205-165812.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Remove usage of dbt.contracts.graph.nodes.ResultNode in dbt/adapters +time: 2023-12-05T16:58:12.932172+09:00 +custom: + Author: michelleark + Issue: "9214" diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 244295994ee..80ebf322523 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -146,7 +146,7 @@ def exception_handler(self, sql: str) -> ContextManager: def set_connection_name(self, name: Optional[str] = None) -> Connection: """Called by 'acquire_connection' in BaseAdapter, which is called by - 'connection_named', called by 'connection_for(node)'. + 'connection_named'. Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index e4806fdc8b7..cf3b78db9d2 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -63,7 +63,6 @@ ) from dbt.common.clients.jinja import CallableMacroGenerator from dbt.contracts.graph.manifest import Manifest, MacroManifest -from dbt.contracts.graph.nodes import ResultNode from dbt.common.events.functions import fire_event, warn_or_error from dbt.adapters.events.types import ( CacheMiss, @@ -285,10 +284,10 @@ def nice_connection_name(self) -> str: return conn.name @contextmanager - def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]: + def connection_named(self, name: str, query_header_context: Any = None) -> Iterator[None]: try: if self.connections.query_header is not None: - self.connections.query_header.set(name, node) + self.connections.query_header.set(name, query_header_context) self.acquire_connection(name) yield finally: @@ -296,11 +295,6 @@ def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iter if self.connections.query_header is not None: self.connections.query_header.reset() - @contextmanager - def connection_for(self, node: ResultNode) -> Iterator[None]: - with self.connection_named(node.unique_id, node): - yield - @available.parse(lambda *a, **k: ("", empty_table())) def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None diff --git a/core/dbt/adapters/base/query_headers.py b/core/dbt/adapters/base/query_headers.py index 249cf263d14..6fa591d45c8 100644 --- a/core/dbt/adapters/base/query_headers.py +++ b/core/dbt/adapters/base/query_headers.py @@ -5,17 +5,16 @@ from dbt.context.manifest import generate_query_header_context from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment -from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.graph.manifest import Manifest from dbt.common.exceptions import DbtRuntimeError -class NodeWrapper: - def __init__(self, node) -> None: - self._inner_node = node +class QueryHeaderContextWrapper: + def __init__(self, context) -> None: + self._inner_context = context def __getattr__(self, name): - return getattr(self._inner_node, name, "") + return getattr(self._inner_context, name, "") class _QueryComment(local): @@ -53,7 +52,7 @@ def set(self, comment: Optional[str], append: bool): self.append = append -QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str] +QueryStringFunc = Callable[[str, Optional[QueryHeaderContextWrapper]], str] class MacroQueryStringSetter: @@ -90,10 +89,10 @@ def add(self, sql: str) -> str: def reset(self): self.set("master", None) - def set(self, name: str, node: Optional[ResultNode]): - wrapped: Optional[NodeWrapper] = None - if node is not None: - wrapped = NodeWrapper(node) + def set(self, name: str, query_header_context: Any): + wrapped: Optional[QueryHeaderContextWrapper] = None + if query_header_context is not None: + wrapped = QueryHeaderContextWrapper(query_header_context) comment_str = self.generator(name, wrapped) append = False diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 59664a72686..516130672da 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -298,7 +298,9 @@ def from_run_result(self, result, start_time, timing_info): def compile_and_execute(self, manifest, ctx): result = None - with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext(): + with self.adapter.connection_named( + self.node.unique_id, self.node + ) if get_flags().INTROSPECT else nullcontext(): ctx.node.update_event_status(node_status=RunningStatus.Compiling) fire_event( NodeCompiling( diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index 3e71345db6c..e4cab9bea15 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -101,7 +101,7 @@ def from_run_result(self, result, start_time, timing_info): def execute(self, compiled_node, manifest): relation = self.adapter.Relation.create_from_source(compiled_node) # given a Source, calculate its freshness. - with self.adapter.connection_for(compiled_node): + with self.adapter.connection_named(compiled_node.unique_id, compiled_node): self.adapter.clear_transaction() adapter_response: Optional[AdapterResponse] = None freshness = None