Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable unit testing in non-root packages #9184

Merged
merged 6 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231130-130948.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support unit tests in non-root packages
time: 2023-11-30T13:09:48.206007-05:00
custom:
Author: gshank
Issue: "8285"
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
limit: Optional[int],
limit: Optional[int] = None,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
Expand Down
72 changes: 35 additions & 37 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@
return self.unit_test_manifest

def parse_unit_test_case(self, test_case: UnitTestDefinition):
package_name = self.root_project.project_name

# Create unit test node based on the node being tested
tested_node = self.manifest.ref_lookup.perform_lookup(
f"model.{package_name}.{test_case.model}", self.manifest
f"model.{test_case.package_name}.{test_case.model}", self.manifest
)
assert isinstance(tested_node, ModelNode)

Expand All @@ -68,7 +66,7 @@
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
package_name=package_name,
package_name=test_case.package_name,
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
Expand All @@ -92,7 +90,7 @@
unit_test_node, # type: ignore
self.root_project,
self.manifest,
package_name,
test_case.package_name,
)
get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
# unit_test_node now has a populated refs/sources
Expand Down Expand Up @@ -121,7 +119,7 @@
project_root = self.root_project.project_root
common_fields = {
"resource_type": NodeType.Model,
"package_name": package_name,
"package_name": test_case.package_name,
"original_file_path": original_input_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
Expand All @@ -142,7 +140,7 @@
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
unique_id=f"model.{test_case.package_name}.{input_name}",
name=input_name,
path=original_input_node.path,
)
Expand All @@ -153,7 +151,7 @@
input_name = f"{unit_test_node.name}__{original_input_node.search_name}__{original_input_node.name}"
input_node = UnitTestSourceDefinition(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
unique_id=f"model.{test_case.package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
Expand Down Expand Up @@ -227,35 +225,6 @@
self.schema_parser = schema_parser
self.yaml = yaml

def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
    """Read the rows of a seed's CSV file from disk.

    Used when the rows for a seed input are not specified in the YAML config.

    Args:
        ref_str: A ``ref(...)`` expression naming the seed, optionally with a
            package.

    Returns:
        One dict per CSV data row, keyed by the CSV header names.

    Raises:
        ParsingError: If the ref does not resolve to a seed node in the
            manifest (the seed is never silently treated as empty).
    """
    # The expression is wrapped in "{{ }}" before extraction; only the first
    # extracted ref is used.
    ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]

    seed_name = ref["name"]
    # A ref without an explicit package falls back to this parser's project.
    package_name = ref.get("package", self.project.project_name)

    seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)

    if not seed_node or seed_node.resource_type != NodeType.Seed:
        if package_name != self.project.project_name:
            # Seed not found in custom package specified
            raise ParsingError(
                f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
            )
        raise ParsingError(
            f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
        )

    seed_path = Path(seed_node.root_path) / seed_node.original_file_path
    # The csv module requires newline="" so that newlines embedded in quoted
    # fields are handled by the reader rather than by the file object.
    with open(seed_path, "r", newline="") as f:
        return list(DictReader(f))

def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
Expand Down Expand Up @@ -351,3 +320,32 @@
fqn.append(model_name)
fqn.append(test_name)
return fqn

def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
    """Read the rows of a seed's CSV file from disk.

    Used when the rows for a seed input are not specified in the YAML config.

    Args:
        ref_str: A ``ref(...)`` expression naming the seed, optionally with a
            package.

    Returns:
        One dict per CSV data row, keyed by the CSV header names.

    Raises:
        ParsingError: If the ref does not resolve to a seed node in the
            manifest (the seed is never silently treated as empty).
    """
    # The expression is wrapped in "{{ }}" before extraction; only the first
    # extracted ref is used.
    ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]

    seed_name = ref["name"]
    # A ref without an explicit package falls back to this parser's project.
    package_name = ref.get("package", self.project.project_name)

    seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)

    if not seed_node or seed_node.resource_type != NodeType.Seed:
        if package_name != self.project.project_name:
            # Seed not found in custom package specified
            # NOTE(review): Codecov (codecov/patch) reported this raise,
            # core/dbt/parser/unit_tests.py#L338, as not covered by tests.
            raise ParsingError(
                f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
            )
        raise ParsingError(
            f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
        )

    seed_path = Path(seed_node.root_path) / seed_node.original_file_path
    # The csv module requires newline="" so that newlines embedded in quoted
    # fields are handled by the reader rather than by the file object.
    with open(seed_path, "r", newline="") as f:
        return list(DictReader(f))
114 changes: 114 additions & 0 deletions tests/functional/unit_testing/test_ut_dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import pytest
from dbt.tests.util import run_dbt, get_unique_ids_in_results
from dbt.tests.fixtures.project import write_project_files

# Fixture file contents for the unit-testing-in-dependency test below.
# NOTE(review): nested YAML keys inside these strings all sit at column 0 in
# this view; confirm the original indentation was not lost before relying on them.

# Written as "dbt_project.yml" of the local dependency; names the package 'local_dep'.
local_dependency__dbt_project_yml = """

name: 'local_dep'
version: '1.0'

seeds:
quote_columns: False

"""

# Written as "models/schema.yml" of the local dependency: declares a source and
# the unit test `test_dep_model_id` for the `dep_model` model.
local_dependency__schema_yml = """
sources:
- name: seed_source
schema: "{{ var('schema_override', target.schema) }}"
tables:
- name: "seed"
columns:
- name: id
tests:
- unique

unit_tests:
- name: test_dep_model_id
model: dep_model
given:
- input: ref('seed')
rows:
- {id: 1, name: Joe}
expect:
rows:
- {name_id: Joe_1}


"""

# Written as "models/dep_model.sql" of the local dependency: builds `name_id`
# by concatenating name, '_' and id from the seed.
local_dependency__dep_model_sql = """
select name || '_' || id as name_id from {{ ref('seed') }}

"""

# Written as "seeds/seed.csv" of the local dependency.
local_dependency__seed_csv = """id,name
1,Mary
2,Sam
3,John
"""

# Root-project model ("my_model.sql") that selects from the dependency's model.
my_model_sql = """
select * from {{ ref('dep_model') }}
"""

# Root-project "schema.yml": unit test `test_my_model_name_id` for `my_model`,
# with `dep_model` as the given input.
my_model_schema_yml = """
unit_tests:
- name: test_my_model_name_id
model: my_model
given:
- input: ref('dep_model')
rows:
- {name_id: Joe_1}
expect:
rows:
- {name_id: Joe_1}
"""


class TestUnitTestingInDependency:
    """Run dbt unit tests defined in a local dependency package and in the root project."""

    @pytest.fixture(scope="class", autouse=True)
    def setUp(self, project_root):
        """Write the files of the ``local_dependency`` package under the project root."""
        write_project_files(
            project_root,
            "local_dependency",
            {
                "dbt_project.yml": local_dependency__dbt_project_yml,
                "models": {
                    "schema.yml": local_dependency__schema_yml,
                    "dep_model.sql": local_dependency__dep_model_sql,
                },
                "seeds": {"seed.csv": local_dependency__seed_csv},
            },
        )

    @pytest.fixture(scope="class")
    def packages(self):
        """Declare the local package as a dependency of the root project."""
        return dict(packages=[dict(local="local_dependency")])

    @pytest.fixture(scope="class")
    def models(self):
        """Model and schema files of the root project."""
        root_files = {}
        root_files["my_model.sql"] = my_model_sql
        root_files["schema.yml"] = my_model_schema_yml
        return root_files

    def test_unit_test_in_dependency(self, project):
        """Check result counts of `dbt test` with and without selectors."""
        for setup_command in ("deps", "seed"):
            run_dbt([setup_command])
        assert len(run_dbt(["run"])) == 2

        all_tests = run_dbt(["test"])
        assert len(all_tests) == 3
        found_ids = get_unique_ids_in_results(all_tests)
        assert "unit_test.local_dep.dep_model.test_dep_model_id" in found_ids

        selector_counts = (
            # two unit tests, 1 in root package, one in local_dep package
            ("test_type:unit", 2),
            # 2 tests in local_dep package
            ("local_dep", 2),
            # 1 test in root package
            ("test", 1),
        )
        for selector, expected_count in selector_counts:
            assert len(run_dbt(["test", "--select", selector])) == expected_count