diff --git a/.changes/unreleased/Under the Hood-20231205-170725.yaml b/.changes/unreleased/Under the Hood-20231205-170725.yaml new file mode 100644 index 00000000000..2018825bcff --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231205-170725.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Introduce RelationConfig Protocol, consolidate Relation.create_from +time: 2023-12-05T17:07:25.33861+09:00 +custom: + Author: michelleark + Issue: "9215" diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index cf3b78db9d2..dffd4f65f2f 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -423,7 +423,7 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]: """ # the cache only cares about executable nodes return { - self.Relation.create_from(self.config, node).without_identifier() + self.Relation.create_from(self.config, node).without_identifier() # type: ignore[arg-type] for node in manifest.nodes.values() if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node) } @@ -470,7 +470,7 @@ def _get_catalog_relations(self, manifest: Manifest) -> List[BaseRelation]: manifest.sources.values(), ) - relations = [self.Relation.create_from(self.config, n) for n in nodes] + relations = [self.Relation.create_from(self.config, n) for n in nodes] # type: ignore[arg-type] return relations def _relations_cache_for_schemas( diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 70a01398f0d..af508f438e5 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -2,8 +2,8 @@ from dataclasses import dataclass, field from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet -from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode from dbt.adapters.contracts.relation import ( + RelationConfig, RelationType, ComponentName, HasQuoting, @@ -11,9 +11,7 @@ Policy, Path, ) -from dbt.common.exceptions import DbtInternalError from dbt.adapters.exceptions import MultipleDatabasesNotAllowedError, ApproximateMatchError -from dbt.node_types import NodeType from dbt.common.utils import filter_null_values, deep_merge from dbt.adapters.utils import classproperty @@ -198,83 +196,50 @@ def quoted(self, identifier): identifier=identifier, ) - @classmethod - def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self: - source_quoting = source.quoting.to_dict(omit_none=True) - source_quoting.pop("column", None) - quote_policy = deep_merge( - cls.get_default_quote_policy().to_dict(omit_none=True), - source_quoting, - kwargs.get("quote_policy", {}), - ) - - return cls.create( - database=source.database, - schema=source.schema, - identifier=source.identifier, - quote_policy=quote_policy, - **kwargs, - ) - @staticmethod def add_ephemeral_prefix(name: str): return f"__dbt__cte__{name}" @classmethod - def create_ephemeral_from_node( + def create_ephemeral_from( cls: Type[Self], - config: HasQuoting, - node: ManifestNode, + relation_config: RelationConfig, ) -> Self: # Note that ephemeral models are based on the name. - identifier = cls.add_ephemeral_prefix(node.name) + identifier = cls.add_ephemeral_prefix(relation_config.name) return cls.create( type=cls.CTE, identifier=identifier, ).quote(identifier=False) @classmethod - def create_from_node( + def create_from( cls: Type[Self], - config: HasQuoting, - node, - quote_policy: Optional[Dict[str, bool]] = None, + quoting: HasQuoting, + relation_config: RelationConfig, **kwargs: Any, ) -> Self: - if quote_policy is None: - quote_policy = {} + quote_policy = kwargs.pop("quote_policy", {}) + + config_quoting = relation_config.quoting_dict + config_quoting.pop("column", None) - quote_policy = dbt.common.utils.merge(config.quoting, quote_policy) + # precedence: kwargs quoting > relation config quoting > base quoting > default quoting + quote_policy = deep_merge( + cls.get_default_quote_policy().to_dict(omit_none=True), + quoting.quoting, + config_quoting, + quote_policy, + ) return cls.create( - database=node.database, - schema=node.schema, - identifier=node.alias, + database=relation_config.database, + schema=relation_config.schema, + identifier=relation_config.identifier, quote_policy=quote_policy, **kwargs, ) - @classmethod - def create_from( - cls: Type[Self], - config: HasQuoting, - node: ResultNode, - **kwargs: Any, - ) -> Self: - if node.resource_type == NodeType.Source: - if not isinstance(node, SourceDefinition): - raise DbtInternalError( - "type mismatch, expected SourceDefinition but got {}".format(type(node)) - ) - return cls.create_from_source(node, **kwargs) - else: - # Can't use ManifestNode here because of parameterized generics - if not isinstance(node, (ParsedNode)): - raise DbtInternalError( - f"type mismatch, expected ManifestNode but got {type(node)}" - ) - return cls.create_from_node(config, node, **kwargs) - @classmethod def create( cls: Type[Self], diff --git a/core/dbt/adapters/contracts/__init__.py b/core/dbt/adapters/contracts/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/core/dbt/adapters/contracts/relation.py b/core/dbt/adapters/contracts/relation.py index c4cead46e45..98e9d8ef878 100644 --- a/core/dbt/adapters/contracts/relation.py +++ b/core/dbt/adapters/contracts/relation.py @@ -22,6 +22,14 @@ class RelationType(StrEnum): Ephemeral = "ephemeral" +class RelationConfig(Protocol): + name: str + database: str + schema: str + identifier: str + quoting_dict: Dict[str, bool] + + class ComponentName(StrEnum): Database = "database" Schema = "schema" diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 45d86bcc307..b182878ae80 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -1,21 +1,11 @@ from dataclasses import dataclass -from typing import ( - Type, - Hashable, - Optional, - ContextManager, - List, - Generic, - TypeVar, - Tuple, -) +from typing import Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, Tuple, Any from typing_extensions import Protocol import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse -from dbt.adapters.contracts.relation import Policy, HasQuoting -from dbt.contracts.graph.nodes import ResultNode +from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.contracts.graph.model_config import BaseConfig from dbt.contracts.graph.manifest import Manifest @@ -42,7 +32,9 @@ def get_default_quote_policy(cls) -> Policy: ... @classmethod - def create_from(cls: Type[Self], config: HasQuoting, node: ResultNode) -> Self: + def create_from( + cls: Type[Self], quoting: HasQuoting, relation_config: RelationConfig, **kwargs: Any + ) -> Self: ... diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index e9094d4518c..2ecb5a0b84a 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -89,11 +89,6 @@ def __init__(self, adapter): def __getattr__(self, key): return getattr(self._relation_type, key) - def create_from_source(self, *args, **kwargs): - # bypass our create when creating from source so as not to mess up - # the source quoting - return self._relation_type.create_from_source(*args, **kwargs) - def create(self, *args, **kwargs): kwargs["quote_policy"] = merge(self._quoting_config, kwargs.pop("quote_policy", {})) return self._relation_type.create(*args, **kwargs) @@ -529,7 +524,7 @@ def resolve( def create_relation(self, target_model: ManifestNode) -> RelationProxy: if target_model.is_ephemeral_model: self.model.set_cte(target_model.unique_id, None) - return self.Relation.create_ephemeral_from_node(self.config, target_model) + return self.Relation.create_ephemeral_from(target_model) else: return self.Relation.create_from(self.config, target_model) @@ -588,7 +583,7 @@ def resolve(self, source_name: str, table_name: str): target_kind="source", disabled=(isinstance(target_source, Disabled)), ) - return self.Relation.create_from_source(target_source) + return self.Relation.create_from(self.config, target_source) # metric` implementations @@ -1475,7 +1470,7 @@ def defer_relation(self) -> Optional[RelationProxy]: object for that stateful other """ if getattr(self.model, "defer_relation", None): - return self.db_wrapper.Relation.create_from_node( + return self.db_wrapper.Relation.create_from( self.config, self.model.defer_relation # type: ignore ) else: diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index fad8a427016..1ab8b9b3e84 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -219,6 +219,13 @@ def __pre_deserialize__(cls, data): data["database"] = None return data + @property + def quoting_dict(self) -> Dict[str, bool]: + if hasattr(self, "quoting"): + return self.quoting.to_dict(omit_none=True) + else: + return {} + @dataclass class MacroDependsOn(dbtClassMixin, Replaceable): diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index d1dabeaf213..e973b5f3592 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -1357,7 +1357,7 @@ def _check_resource_uniqueness( # the full node name is really defined by the adapter's relation relation_cls = get_relation_class_by_name(config.credentials.type) - relation = relation_cls.create_from(config=config, node=node) + relation = relation_cls.create_from(quoting=config, relation_config=node) # type: ignore[arg-type] full_node_name = str(relation) existing_alias = alias_resources.get(full_node_name) diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 089bc7be265..7a782682f65 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -108,7 +108,7 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe # cache the 'other' schemas too! if node.defer_relation: # type: ignore - other_relation = adapter.Relation.create_from_node( + other_relation = adapter.Relation.create_from( self.config, node.defer_relation # type: ignore ) result.add(other_relation.without_identifier()) diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index e4cab9bea15..3f76d751a91 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -99,7 +99,7 @@ def from_run_result(self, result, start_time, timing_info): return result def execute(self, compiled_node, manifest): - relation = self.adapter.Relation.create_from_source(compiled_node) + relation = self.adapter.Relation.create_from(self.config, compiled_node) # given a Source, calculate its freshness. with self.adapter.connection_named(compiled_node.unique_id, compiled_node): self.adapter.clear_transaction()