3 changes: 1 addition & 2 deletions .github/workflows/ci.yaml
@@ -85,9 +85,8 @@ jobs:
poetry env use ${{ matrix.python-version }}
poetry add "apache-airflow~=${{ matrix.airflow-version }}.0" \
"dbt-core~=${{ matrix.dbt-version }}.0" \
"dbt-postgres~=${{ matrix.dbt-version }}.0" \
--python ${{ matrix.python-version }}
poetry install -E postgres --with dev
poetry install -E adapters --with dev
poetry run airflow db migrate
poetry run airflow connections create-default-connections

205 changes: 164 additions & 41 deletions airflow_dbt_python/hooks/dbt.py
@@ -4,13 +4,17 @@

import json
import logging
import os
import re
import sys
from contextlib import contextmanager
from copy import copy
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
@@ -69,6 +73,7 @@ class DbtConnectionParam(NamedTuple):
name: str
store_override_name: Optional[str] = None
default: Optional[Any] = None
depends_on: Callable[[Connection], bool] = lambda x: True

@property
def override_name(self):
@@ -119,9 +124,8 @@ class DbtHook(BaseHook):
conn_params: list[Union[DbtConnectionParam, str]] = [
DbtConnectionParam("conn_type", "type"),
"host",
DbtConnectionParam("conn_id", "dbname"),
"schema",
DbtConnectionParam("login", "user"),
"login",
"password",
"port",
]
@@ -240,8 +244,6 @@ def run_dbt_task(
nearest_project_dir = get_nearest_project_dir(config.project_dir)

with chdir_ctx(nearest_project_dir):
self.ensure_profiles(config)

with adapter_management():
task, runtime_config = config.create_dbt_task(
extra_target, write_perf_info
@@ -412,19 +414,6 @@ def setup_dbt_logging(self, debug: Optional[bool]):
configured_file.setLevel("INFO")
configured_file.propagate = False

def ensure_profiles(self, config: BaseConfig):
"""Ensure a profiles file exists."""
if config.profiles_dir is not None:
# We expect one to exist given that we have passed a profiles_dir.
return

profiles_path = Path.home() / ".dbt/profiles.yml"
config.profiles_dir = str(profiles_path.parent)
if not profiles_path.exists():
profiles_path.parent.mkdir(exist_ok=True)
with profiles_path.open("w", encoding="utf-8") as f:
f.write("flags:\n send_anonymous_usage_stats: false\n")

def get_dbt_target_from_connection(
self, target: Optional[str]
) -> Optional[dict[str, Any]]:
@@ -454,11 +443,14 @@ def get_dbt_target_from_connection(
)
return None

details = self.get_dbt_details_from_connection(conn)
db_hook_class = self.get_db_hook_class(conn)

details = db_hook_class.get_dbt_details_from_connection(conn)

return {conn_id: details}

def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
@classmethod
def get_dbt_details_from_connection(cls, conn: Connection) -> dict[str, Any]:
"""Extract dbt connection details from Airflow Connection.

dbt connection details may be present as Airflow Connection attributes or in the
@@ -476,8 +468,10 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
A dictionary of dbt connection details.
"""
dbt_details = {}
for param in self.conn_params:
for param in cls.conn_params:
if isinstance(param, DbtConnectionParam):
if not param.depends_on(conn):
continue
key = param.override_name
value = getattr(conn, param.name, param.default)
else:
@@ -491,11 +485,13 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:

extra = conn.extra_dejson

if not self.conn_extra_params:
if not cls.conn_extra_params:
return {**dbt_details, **extra}

for param in self.conn_extra_params:
for param in cls.conn_extra_params:
if isinstance(param, DbtConnectionParam):
if not param.depends_on(conn):
continue
key = param.override_name
value = extra.get(param.name, param.default)
else:
@@ -509,43 +505,91 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:

return dbt_details
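
An editorial aside, not part of the diff: a minimal sketch of what the extraction above is expected to produce for a hypothetical connection, assuming the base hook's conn_extra_params list is empty so the early-return branch merges all extras wholesale.

from airflow.models.connection import Connection

conn = Connection(
    conn_id="warehouse",          # hypothetical connection, values made up
    conn_type="postgres",
    host="localhost",
    schema="analytics",
    login="airflow",
    password="s3cr3t",
    port=5432,
    extra='{"threads": 4}',
)

# conn_params are read from Connection attributes and renamed via override_name
# (conn_type -> type); a parameter is skipped when its depends_on callable
# returns False; extras are merged in at the end.
details = DbtHook.get_dbt_details_from_connection(conn)
# Roughly: {"type": "postgres", "host": "localhost", "schema": "analytics",
#           "login": "airflow", "password": "s3cr3t", "port": 5432, "threads": 4}
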

@classmethod
def get_db_hook_class(cls, conn: Connection) -> type[DbtHook]:
"""Get a dbt hook class depend on Airflow connection type."""
known_dbt_hooks = (
DbtPostgresHook,
DbtRedshiftHook,
DbtSnowflakeHook,
DbtBigQueryHook,
DbtSparkHook,
)
for hook in known_dbt_hooks:
if hook.conn_type == conn.conn_type:
return hook
return DbtHook
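
A brief usage note on the dispatch above (editorial illustration with hypothetical connections): known conn_type values resolve to the matching adapter hook, and everything else falls back to the generic DbtHook, which keeps get_dbt_target_from_connection adapter-agnostic.

from airflow.models.connection import Connection

DbtHook.get_db_hook_class(Connection(conn_id="rs", conn_type="redshift"))  # -> DbtRedshiftHook
DbtHook.get_db_hook_class(Connection(conn_id="my", conn_type="mysql"))     # -> DbtHook (fallback)
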


class DbtPostgresHook(DbtHook):
"""A hook to interact with dbt using a Postgres connection."""

conn_type = "postgres"
hook_name = "dbt Postgres Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", "postgres"),
DbtConnectionParam("conn_type", "type", conn_type),
"host",
"schema",
DbtConnectionParam("schema", default="public"),
DbtConnectionParam("login", "user"),
"password",
"port",
DbtConnectionParam("port", default=5432),
]
conn_extra_params = [
"dbname",
"threads",
"keepalives_idle",
DbtConnectionParam("dbname", "database", "postgres"),
"connect_timeout",
"retries",
"search_path",
"role",
"search_path",
"keepalives_idle",
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"retries",
]

@classmethod
def get_dbt_details_from_connection(cls, conn: Connection) -> dict[str, Any]:
"""Extract dbt connection details from Airflow Connection.

dbt connection details may be present as Airflow Connection attributes or in the
Connection's extras. This class' conn_params and conn_extra_params will be used
to fetch required attributes from attributes and extras respectively. If
conn_extra_params is empty, we merge parameters with all extras.

Subclasses may override this class attributes to narrow down the connection
details for a specific dbt target (like Postgres, or Redshift).

Args:
conn: The Airflow Connection to extract dbt connection details from.

Returns:
A dictionary of dbt connection details.
"""
if "options" in conn.extra_dejson:
conn = copy(conn)
extra_dejson = conn.extra_dejson
options = extra_dejson.pop("options")
for k, v in re.findall(r"-c (\w+)=(.*)$", options):
Review comment from @tomasfarias (Owner), Apr 1, 2025:

issue(blocking): Could we modify the docstring (or add a comment) explaining what we are trying to find with this regular expression?

Reply from @tomasfarias (Owner):

This is now in a different module, but the same regular expression is present.

extra_dejson[k] = v
conn.extra = json.dumps(extra_dejson)
return super().get_dbt_details_from_connection(conn)
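
On the regular expression question above, an editorial reading rather than text from the PR: Airflow Postgres connections may carry a libpq-style options string in their extras (for example "-c search_path=analytics"), and the pattern lifts the setting name and its value out of that string so it can be passed to dbt as a regular extra. A minimal sketch:

import re

options = "-c search_path=analytics"   # hypothetical extras value
pairs = re.findall(r"-c (\w+)=(.*)$", options)
# -> [("search_path", "analytics")]; each pair is written back into
# extra_dejson, so the dbt Postgres profile ends up with search_path: analytics.

Worth noting that with several "-c" flags in one string the greedy (.*)$ captures through to the end, so only the first setting name is matched.
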


class DbtRedshiftHook(DbtPostgresHook):
"""A hook to interact with dbt using a Redshift connection."""

conn_type = "redshift"
hook_name = "dbt Redshift Hook"
conn_extra_params = DbtPostgresHook.conn_extra_params + [
"ra3_node",
"method",
"cluster_id",
"iam_profile",
"iam_duration_secons",
"autocreate",
"db_groups",
"ra3_node",
"connect_timeout",
"role",
"region",
]


@@ -555,22 +599,101 @@ class DbtSnowflakeHook(DbtHook):
conn_type = "snowflake"
hook_name = "dbt Snowflake Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", "postgres"),
DbtConnectionParam("conn_type", "type", conn_type),
"host",
"schema",
DbtConnectionParam("login", "user"),
"password",
DbtConnectionParam(
"login",
"user",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") != "oauth",
Review comment from @tomasfarias (Owner), Apr 1, 2025:

nitpick: I feel like we could just have a utility function that we can re-use here instead of using lambda:

def make_extra_dejson_compare_callable(key, default, comparison_operator, expected):
    def compare_extra_dejson_value(conn):
        return comparison_operator(conn.extra_dejson.get(key, default), expected)

    return compare_extra_dejson_value

Then here:

import operator
...
        DbtConnectionParam(
            "login",
            "user",
            depends_on=make_extra_dejson_compare_callable("authenticator", "", operator.ne, "oauth"),
        ),

This one is kind of a nitpick. I do prefer an approach like this to using lambda, but I think the other comments are more important.

Reply from the PR author (Contributor):

I liked this idea and decided to combine recurring parameters together.

),
DbtConnectionParam(
"login",
"oauth_client_id",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
),
DbtConnectionParam(
"password",
depends_on=lambda x: not any(
(
*(k in x.extra_dejson for k in ("private_key_file", "private_key_content")),
Review comment from @tomasfarias (Owner), Apr 1, 2025:

issue(blocking): I don't think we care whether there are private key file and contents in extra_dejson as long as we have our password and .get("authenticator", "") == "oauth".

Reply from @tomasfarias (Owner):

I think we should keep the mental model small for users with only one key to consider to determine what everything else means (authenticator).

x.extra_dejson.get("authenticator", "") == "oauth",
),
),
),
DbtConnectionParam(
"password",
"private_key_passphrase",
depends_on=lambda x: any(
k in x.extra_dejson for k in ("private_key_file", "private_key_content")
),
),
DbtConnectionParam(
"password",
"oauth_client_secret",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
),
]
conn_extra_params = [
"account",
"role",
"database",
"warehouse",
"threads",
"client_session_keep_alive",
"role",
"authenticator",
"query_tag",
"connect_retries",
"client_session_keep_alive",
"connect_timeout",
"retry_on_database_errors",
"retry_all",
"reuse_connections",
"account",
"database",
DbtConnectionParam("refresh_token", "token"),
DbtConnectionParam("private_key_file", "private_key_path"),
DbtConnectionParam("private_key_content", "private_key"),
]
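
A hedged illustration (made-up values, not from the PR) of how the depends_on callables above route Snowflake credentials: with a private key configured, the Connection password is surfaced as the key passphrase rather than as a password, and login still maps to user because authenticator is not "oauth".

from airflow.models.connection import Connection

keypair_conn = Connection(
    conn_id="sf",                 # hypothetical connection
    conn_type="snowflake",
    login="dbt_user",
    password="key-passphrase",
    extra='{"account": "xy12345", "private_key_file": "/keys/dbt_rsa.p8"}',
)
details = DbtSnowflakeHook.get_dbt_details_from_connection(keypair_conn)
# Roughly (assuming unset attributes are dropped):
# {"type": "snowflake", "user": "dbt_user",
#  "private_key_passphrase": "key-passphrase",
#  "account": "xy12345", "private_key_path": "/keys/dbt_rsa.p8"}
# With authenticator set to "oauth" in extras, login/password would instead be
# emitted as oauth_client_id/oauth_client_secret.
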


class DbtBigQueryHook(DbtHook):
"""A hook to interact with dbt using a BigQuery connection."""

conn_type = "bigquery"
hook_name = "dbt BigQuery Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"schema",
]
conn_extra_params = [
DbtConnectionParam("keyfile_path", "keyfile"),
DbtConnectionParam("keyfile_dict", "keyfile_json"),
"method",
"database",
"schema",
"refresh_token",
"client_id",
"client_secret",
"token_uri",
"OPTIONAL_CONFIG",
]


class DbtSparkHook(DbtHook):
"""A hook to interact with dbt using a Spark connection."""

conn_type = "spark"
hook_name = "dbt Spark Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"host",
"port",
"schema",
DbtConnectionParam("login", "user"),
DbtConnectionParam(
"password",
depends_on=lambda x: x.extra_dejson.get("method", "") == "thrift",
),
DbtConnectionParam(
"password",
"token",
depends_on=lambda x: x.extra_dejson.get("method", "") != "thrift",
),
]
conn_extra_params = []
3 changes: 1 addition & 2 deletions airflow_dbt_python/utils/configs.py
@@ -282,8 +282,7 @@ def _runtime_initialize():
task._flattened_nodes.append(task.manifest.sources[uid])
else:
raise DbtException(
f"Node selection returned {uid}, expected a node or a "
f"source"
f"Node selection returned {uid}, expected a node or a source"
)
task.num_nodes = len(
[n for n in task._flattened_nodes if not n.is_ephemeral_model]