Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ repos:
- id: rst-backticks
- id: python-check-mock-methods
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.3
rev: v1.5.1
Comment thread
harels marked this conversation as resolved.
hooks:
- id: remove-crlf
- id: remove-tabs
Expand Down
3 changes: 3 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml")
DEFAULT_DBT_PROFILE_NAME = "cosmos_profile"
DEFAULT_DBT_TARGET_NAME = "cosmos_target"
DBT_LOG_PATH_ENVVAR = "DBT_LOG_PATH"
DBT_TARGET_PATH_ENVVAR = "DBT_TARGET_PATH"
DBT_LOG_FILENAME = "dbt.log"


class LoadMode(Enum):
Expand Down
105 changes: 63 additions & 42 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import json
import os
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from subprocess import Popen, PIPE
from typing import Any

from cosmos.config import ProfileConfig
from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode
from cosmos.constants import (
DbtResourceType,
ExecutionMode,
LoadMode,
DBT_LOG_FILENAME,
DBT_LOG_PATH_ENVVAR,
DBT_TARGET_PATH_ENVVAR,
)
from cosmos.dbt.executable import get_system_dbt
from cosmos.dbt.parser.project import DbtProject as LegacyDbtProject
from cosmos.dbt.project import DbtProject
Expand All @@ -18,8 +26,6 @@

logger = get_logger(__name__)

# TODO replace inline constants


class CosmosLoadDbtException(Exception):
"""
Expand Down Expand Up @@ -143,9 +149,12 @@ def load_via_dbt_ls(self) -> None:
if self.select:
command.extend(["--select", *self.select])

with self.profile_config.ensure_profile() as (profile_path, env_vars):
with self.profile_config.ensure_profile() as profile_values:
(profile_path, env_vars) = profile_values
command.extend(
[
"--project-dir",
str(self.project.dir),
"--profiles-dir",
str(profile_path.parent),
"--profile",
Expand All @@ -158,46 +167,58 @@ def load_via_dbt_ls(self) -> None:
env = os.environ.copy()
env.update(env_vars)

logger.info("Running command: `%s`", " ".join(command))
logger.info("Environment variable keys: %s", env.keys())
process = Popen(
command,
stdout=PIPE,
stderr=PIPE,
cwd=self.project.dir,
universal_newlines=True,
env=env,
)

stdout, stderr = process.communicate()

logger.debug("dbt output:\n %s", stdout)

if stderr or "Runtime Error" in stdout:
details = stderr or stdout
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{details}")

nodes = {}
for line in stdout.split("\n"):
try:
node_dict = json.loads(line.strip())
except json.decoder.JSONDecodeError:
logger.debug("Skipped dbt ls line: %s", line)
else:
node = DbtNode(
name=node_dict["name"],
unique_id=node_dict["unique_id"],
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dir / node_dict["original_file_path"],
tags=node_dict["tags"],
config=node_dict["config"],
with tempfile.TemporaryDirectory() as tmpdir:
logger.info("Running command: `%s`", " ".join(command))
logger.info("Environment variable keys: %s", env.keys())
log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir)
target_dir = Path(env.get(DBT_TARGET_PATH_ENVVAR) or tmpdir)
env[DBT_LOG_PATH_ENVVAR] = str(log_dir)
env[DBT_TARGET_PATH_ENVVAR] = str(target_dir)

process = Popen(
command,
stdout=PIPE,
stderr=PIPE,
cwd=tmpdir,
universal_newlines=True,
env=env,
)
nodes[node.unique_id] = node
logger.debug("Parsed dbt resource `%s` of type `%s`", node.unique_id, node.resource_type)

self.nodes = nodes
self.filtered_nodes = nodes
stdout, stderr = process.communicate()

logger.debug("dbt output: %s", stdout)
log_filepath = log_dir / DBT_LOG_FILENAME
logger.debug("dbt logs available in: %s", log_filepath)
if log_filepath.exists():
with open(log_filepath) as logfile:
for line in logfile:
logger.debug(line.strip())
Comment thread
tatiana marked this conversation as resolved.

if stderr or "Runtime Error" in stdout:
details = stderr or stdout
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{details}")

nodes = {}
for line in stdout.split("\n"):
try:
node_dict = json.loads(line.strip())
except json.decoder.JSONDecodeError:
logger.debug("Skipped dbt ls line: %s", line)
else:
node = DbtNode(
name=node_dict["name"],
unique_id=node_dict["unique_id"],
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dir / node_dict["original_file_path"],
tags=node_dict["tags"],
config=node_dict["config"],
)
nodes[node.unique_id] = node
logger.debug("Parsed dbt resource `%s` of type `%s`", node.unique_id, node.resource_type)

self.nodes = nodes
self.filtered_nodes = nodes

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))
Expand Down
48 changes: 48 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch

Expand All @@ -10,10 +12,28 @@
from cosmos.profiles import PostgresUserPasswordProfileMapping

DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt"
DBT_PIPELINE_NAME = "jaffle_shop"
SAMPLE_MANIFEST = Path(__file__).parent.parent / "sample/manifest.json"
SAMPLE_MANIFEST_PY = Path(__file__).parent.parent / "sample/manifest_python.json"


@pytest.fixture
def tmp_dbt_project_dir():
    """
    Yield a temporary root directory holding a clean copy of the sample dbt project.

    The project is copied to ``<temporary root>/<DBT_PIPELINE_NAME>`` and any ``logs``
    or ``target`` folders brought along by the copy are removed. The yielded value is
    the temporary root (not the project folder itself); the whole root is deleted
    once the test using the fixture has finished.
    """
    tmp_root = Path(tempfile.mkdtemp())
    project_copy = tmp_root / DBT_PIPELINE_NAME
    shutil.copytree(DBT_PROJECTS_ROOT_DIR / DBT_PIPELINE_NAME, project_copy)

    # Start from a pristine project: drop artefact folders left by earlier dbt runs.
    for artefact_dir in ("logs", "target"):
        shutil.rmtree(project_copy / artefact_dir, ignore_errors=True)

    yield tmp_root

    # Teardown: remove the temporary root and everything copied into it.
    shutil.rmtree(tmp_root, ignore_errors=True)


@pytest.mark.parametrize(
"pipeline_name,manifest_filepath,model_filepath",
[("jaffle_shop", SAMPLE_MANIFEST, "customers.sql"), ("jaffle_shop_python", SAMPLE_MANIFEST_PY, "customers.py")],
Expand Down Expand Up @@ -108,6 +128,34 @@ def test_load(
assert load_function.called


@pytest.mark.integration
@patch("cosmos.dbt.graph.Popen")
def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_popen, tmp_dbt_project_dir):
    """
    Check that ``load_via_dbt_ls`` keeps dbt artefacts out of the project folder.

    ``Popen`` is mocked, so no dbt process is started; the test inspects the
    arguments ``Popen`` was called with. It asserts that neither ``target`` nor
    ``logs`` exists in the fixture root or in the project directory, before and
    after loading, and that the subprocess working directory is a location other
    than the project directory which no longer exists once loading has finished.
    """
    mock_popen().communicate.return_value = ("", "")

    dbt_project = DbtProject(name=DBT_PIPELINE_NAME, root_dir=tmp_dbt_project_dir)
    # The fixture yields the temporary root, while dbt_project.dir is the folder the
    # graph loader works against; check both so the assertions are not vacuous.
    watched_dirs = (tmp_dbt_project_dir, dbt_project.dir)
    for watched_dir in watched_dirs:
        assert not (watched_dir / "target").exists()
        assert not (watched_dir / "logs").exists()

    dbt_graph = DbtGraph(
        project=dbt_project,
        profile_config=ProfileConfig(
            profile_name="default",
            target_name="default",
            profile_mapping=PostgresUserPasswordProfileMapping(
                conn_id="airflow_db",
                profile_args={"schema": "public"},
            ),
        ),
    )
    dbt_graph.load_via_dbt_ls()

    for watched_dir in watched_dirs:
        assert not (watched_dir / "target").exists()
        assert not (watched_dir / "logs").exists()

    # Read the working directory from the ``cwd`` keyword given to Popen rather than
    # from a positional slot of the dbt command line, which holds a CLI argument.
    used_cwd = Path(mock_popen.call_args[1]["cwd"])
    assert used_cwd != dbt_project.dir
    assert not used_cwd.exists()


@pytest.mark.integration
def test_load_via_dbt_ls_with_exclude():
dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR)
Expand Down