Skip to content

Commit

Permalink
model versions (#7287)
Browse files Browse the repository at this point in the history
model versioning and versioned ref resolution
  • Loading branch information
MichelleArk authored Apr 12, 2023
1 parent 56f8f8a commit c7ebc89
Show file tree
Hide file tree
Showing 33 changed files with 2,193 additions and 305 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230406-101019.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: model versions
time: 2023-04-06T10:10:19.794672-04:00
custom:
Author: michelleark
Issue: '#7263'
90 changes: 63 additions & 27 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
SourceDefinition,
Resource,
ManifestNode,
RefArgs,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
from dbt.events.functions import get_metadata_vars
from dbt.exceptions import (
CompilationError,
Expand Down Expand Up @@ -212,16 +214,17 @@ def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:

class BaseRefResolver(BaseResolver):
@abc.abstractmethod
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
def resolve(
self, name: str, package: Optional[str] = None, version: Optional[NodeVersion] = None
) -> RelationProxy:
...

def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
if package is None:
return [name]
else:
return [package, name]
def _repack_args(
self, name: str, package: Optional[str], version: Optional[NodeVersion]
) -> RefArgs:
return RefArgs(package=package, name=name, version=version)

def validate_args(self, name: str, package: Optional[str]):
def validate_args(self, name: str, package: Optional[str], version: Optional[NodeVersion]):
if not isinstance(name, str):
raise CompilationError(
f"The name argument to ref() must be a string, got {type(name)}"
Expand All @@ -232,18 +235,26 @@ def validate_args(self, name: str, package: Optional[str]):
f"The package argument to ref() must be a string or None, got {type(package)}"
)

def __call__(self, *args: str) -> RelationProxy:
if version is not None and not isinstance(version, (str, int, float)):
raise CompilationError(
f"The version argument to ref() must be a string, int, float, or None - got {type(version)}"
)

def __call__(self, *args: str, **kwargs) -> RelationProxy:
name: str
package: Optional[str] = None
version: Optional[NodeVersion] = None

if len(args) == 1:
name = args[0]
elif len(args) == 2:
package, name = args
else:
raise RefArgsError(node=self.model, args=args)
self.validate_args(name, package)
return self.resolve(name, package)

version = kwargs.get("version") or kwargs.get("v")
self.validate_args(name, package, version)
return self.resolve(name, package, version)


class BaseSourceResolver(BaseResolver):
Expand Down Expand Up @@ -448,8 +459,10 @@ def __getattr__(self, name):

# `ref` implementations
class ParseRefResolver(BaseRefResolver):
def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:
self.model.refs.append(self._repack_args(name, package))
def resolve(
self, name: str, package: Optional[str] = None, version: Optional[NodeVersion] = None
) -> RelationProxy:
self.model.refs.append(self._repack_args(name, package, version))

return self.Relation.create_from(self.config, self.model)

Expand All @@ -458,10 +471,16 @@ def resolve(self, name: str, package: Optional[str] = None) -> RelationProxy:


class RuntimeRefResolver(BaseRefResolver):
def resolve(self, target_name: str, target_package: Optional[str] = None) -> RelationProxy:
def resolve(
self,
target_name: str,
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
target_model = self.manifest.resolve_ref(
target_name,
target_package,
target_version,
self.current_project,
self.model.package_name,
)
Expand All @@ -472,23 +491,28 @@ def resolve(self, target_name: str, target_package: Optional[str] = None) -> Rel
target_name=target_name,
target_kind="node",
target_package=target_package,
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)
self.validate(target_model, target_name, target_package)
return self.create_relation(target_model, target_name)
self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)

def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_model)
else:
return self.Relation.create_from(self.config, target_model)

def validate(
self, resolved: ManifestNode, target_name: str, target_package: Optional[str]
self,
resolved: ManifestNode,
target_name: str,
target_package: Optional[str],
target_version: Optional[NodeVersion],
) -> None:
if resolved.unique_id not in self.model.depends_on.nodes:
args = self._repack_args(target_name, target_package)
args = self._repack_args(target_name, target_package, target_version)
raise RefBadContextError(node=self.model, args=args)


Expand All @@ -498,16 +522,17 @@ def validate(
resolved: ManifestNode,
target_name: str,
target_package: Optional[str],
target_version: Optional[NodeVersion],
) -> None:
pass

def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model:
# In operations, we can't ref() ephemeral nodes, because
# Macros do not support set_cte
raise OperationsCannotRefEphemeralNodesError(target_model.name, node=self.model)
else:
return super().create_relation(target_model, name)
return super().create_relation(target_model)


# `source` implementations
Expand Down Expand Up @@ -1408,10 +1433,18 @@ def generate_runtime_macro_context(


class ExposureRefResolver(BaseResolver):
def __call__(self, *args) -> str:
if len(args) not in (1, 2):
def __call__(self, *args, **kwargs) -> str:
package = None
if len(args) == 1:
name = args[0]
elif len(args) == 2:
package, name = args
else:
raise RefArgsError(node=self.model, args=args)
self.model.refs.append(list(args))

version = kwargs.get("version") or kwargs.get("v")

self.model.refs.append(RefArgs(package=package, name=name, version=version))
return ""


Expand Down Expand Up @@ -1461,19 +1494,22 @@ def generate_parse_exposure(


class MetricRefResolver(BaseResolver):
def __call__(self, *args) -> str:
def __call__(self, *args, **kwargs) -> str:
package = None
if len(args) == 1:
name = args[0]
elif len(args) == 2:
package, name = args
else:
raise RefArgsError(node=self.model, args=args)
self.validate_args(name, package)
self.model.refs.append(list(args))

version = kwargs.get("version") or kwargs.get("v")
self.validate_args(name, package, version)

self.model.refs.append(RefArgs(package=package, name=name, version=version))
return ""

def validate_args(self, name, package):
def validate_args(self, name, package, version):
if not isinstance(name, str):
raise ParsingError(
f"In a metrics section in {self.model.original_file_path} "
Expand Down
40 changes: 32 additions & 8 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ResultNode,
BaseNode,
)
from dbt.contracts.graph.unparsed import SourcePatch
from dbt.contracts.graph.unparsed import SourcePatch, NodeVersion
from dbt.contracts.graph.manifest_upgrade import upgrade_manifest_json
from dbt.contracts.files import SourceFile, SchemaSourceFile, FileHash, AnySourceFile
from dbt.contracts.util import BaseArtifactMetadata, SourceKey, ArtifactMixin, schema_version
Expand Down Expand Up @@ -146,6 +146,7 @@ def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SourceDef
class RefableLookup(dbtClassMixin):
# model, seed, snapshot
_lookup_types: ClassVar[set] = set(NodeType.refable())
_versioned_types: ClassVar[set] = set(NodeType.versioned())

# refables are actually unique, so the Dict[PackageName, UniqueID] will
# only ever have exactly one value, but doing 3 dict lookups instead of 1
Expand All @@ -154,11 +155,19 @@ def __init__(self, manifest: "Manifest"):
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

def get_unique_id(self, key, package: Optional[PackageName]):
def get_unique_id(self, key, package: Optional[PackageName], version: Optional[NodeVersion]):
if version:
key = f"{key}.v{version}"
return find_unique_id_for_package(self.storage, key, package)

def find(self, key, package: Optional[PackageName], manifest: "Manifest"):
unique_id = self.get_unique_id(key, package)
def find(
self,
key,
package: Optional[PackageName],
version: Optional[NodeVersion],
manifest: "Manifest",
):
unique_id = self.get_unique_id(key, package, version)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None
Expand All @@ -167,7 +176,15 @@ def add_node(self, node: ManifestNode):
if node.resource_type in self._lookup_types:
if node.name not in self.storage:
self.storage[node.name] = {}
self.storage[node.name][node.package_name] = node.unique_id

if node.resource_type in self._versioned_types and node.version:
if node.search_name not in self.storage:
self.storage[node.search_name] = {}
self.storage[node.search_name][node.package_name] = node.unique_id
if node.is_latest_version: # type: ignore
self.storage[node.name][node.package_name] = node.unique_id
else:
self.storage[node.name][node.package_name] = node.unique_id

def populate(self, manifest):
for node in manifest.nodes.values():
Expand Down Expand Up @@ -233,7 +250,12 @@ def add_node(self, node):

# This should return a list of disabled nodes. It's different from
# the other Lookup functions in that it returns full nodes, not just unique_ids
def find(self, search_name, package: Optional[PackageName]):
def find(
self, search_name, package: Optional[PackageName], version: Optional[NodeVersion] = None
):
if version:
search_name = f"{search_name}.v{version}"

if search_name not in self.storage:
return None

Expand All @@ -252,6 +274,7 @@ def find(self, search_name, package: Optional[PackageName]):

class AnalysisLookup(RefableLookup):
_lookup_types: ClassVar[set] = set([NodeType.Analysis])
_versioned_types: ClassVar[set] = set()


def _search_packages(
Expand Down Expand Up @@ -900,6 +923,7 @@ def resolve_ref(
self,
target_model_name: str,
target_model_package: Optional[str],
target_model_version: Optional[NodeVersion],
current_project: str,
node_package: str,
) -> MaybeNonSource:
Expand All @@ -909,14 +933,14 @@ def resolve_ref(

candidates = _search_packages(current_project, node_package, target_model_package)
for pkg in candidates:
node = self.ref_lookup.find(target_model_name, pkg, self)
node = self.ref_lookup.find(target_model_name, pkg, target_model_version, self)

if node is not None and node.config.enabled:
return node

# it's possible that the node is disabled
if disabled is None:
disabled = self.disabled_lookup.find(target_model_name, pkg)
disabled = self.disabled_lookup.find(target_model_name, pkg, target_model_version)

if disabled:
return Disabled(disabled[0])
Expand Down
18 changes: 18 additions & 0 deletions core/dbt/contracts/graph/manifest_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,24 @@ def rename_sql_attr(node_content: dict) -> dict:
return node_content


def upgrade_ref_content(node_content: dict) -> dict:
# In v1.5 we switched Node.refs from List[List[str]] to List[Dict[str, Union[NodeVersion, str]]]
# Previous versions did not have a version keyword argument for ref
if "refs" in node_content:
upgraded_refs = []
for ref in node_content["refs"]:
if isinstance(ref, list):
if len(ref) == 1:
upgraded_refs.append({"package": None, "name": ref[0], "version": None})
else:
upgraded_refs.append({"package": ref[0], "name": ref[1], "version": None})
node_content["refs"] = upgraded_refs
return node_content


def upgrade_node_content(node_content):
rename_sql_attr(node_content)
upgrade_ref_content(node_content)
if node_content["resource_type"] != "seed" and "root_path" in node_content:
del node_content["root_path"]

Expand Down Expand Up @@ -92,9 +108,11 @@ def upgrade_manifest_json(manifest: dict) -> dict:
for metric_content in manifest.get("metrics", {}).values():
# handle attr renames + value translation ("expression" -> "derived")
metric_content = rename_metric_attr(metric_content)
metric_content = upgrade_ref_content(metric_content)
if "root_path" in metric_content:
del metric_content["root_path"]
for exposure_content in manifest.get("exposures", {}).values():
exposure_content = upgrade_ref_content(exposure_content)
if "root_path" in exposure_content:
del exposure_content["root_path"]
for source_content in manifest.get("sources", {}).values():
Expand Down
Loading

0 comments on commit c7ebc89

Please sign in to comment.