361 changes: 21 additions & 340 deletions airflow_dbt_python/hooks/dbt.py

Large diffs are not rendered by default.

351 changes: 351 additions & 0 deletions airflow_dbt_python/hooks/target.py
@@ -0,0 +1,351 @@
"""Provides a hook to get a dbt profile based on the Airflow connection."""

from __future__ import annotations

import json
import re
import warnings
from abc import ABC, ABCMeta
from copy import copy
from typing import (
Any,
Callable,
ClassVar,
NamedTuple,
Optional,
Union,
)

from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection


class DbtConnectionParam(NamedTuple):
"""A tuple indicating connection parameters relevant to dbt.
Attributes:
name: The name of the connection parameter. This name will be used to get the
parameter from an Airflow Connection or its extras.
store_override_name: A new name for the connection parameter. If not None, this
is the name used in a dbt profiles.
default: A default value if the parameter is not found.
"""

name: str
store_override_name: Optional[str] = None
default: Optional[Any] = None
depends_on: Callable[[Connection], bool] = lambda x: True

@property
def override_name(self):
"""Returns the override_name if defined, otherwise defaults to name.
>>> DbtConnectionParam("login", "user").override_name
'user'
>>> DbtConnectionParam("port").override_name
'port'
"""
if self.store_override_name is None:
return self.name
return self.store_override_name
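A quick illustration of how these tuples are meant to be read; the parameter combination below is made up for the example and is not taken verbatim from the hooks in this file:

```python
# Map the Airflow `login` attribute to the dbt profile key `user`, but only when the
# connection's extras do not configure key-pair authentication (illustrative condition).
param = DbtConnectionParam(
    "login",
    "user",
    depends_on=lambda conn: "private_key_file" not in conn.extra_dejson,
)

assert param.override_name == "user"
# With no override, the dbt key is simply the Airflow attribute name.
assert DbtConnectionParam("port").override_name == "port"
```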


class DbtConnectionHookMeta(ABCMeta):
"""A hook metaclass to collect all subclasses of DbtConnectionHook."""

_dbt_hooks_by_conn_type: ClassVar[dict[str, DbtConnectionHookMeta]] = {}
conn_type: str

def __new__(cls, name, bases, attrs, **kwargs) -> DbtConnectionHookMeta:
"""Adds each DbtConnectionHook subclass to the dict based on its conn_type."""
new_hook_cls = super().__new__(cls, name, bases, attrs)
if new_hook_cls.conn_type in cls._dbt_hooks_by_conn_type:
warnings.warn(
f"The connection type `{new_hook_cls.conn_type}`"
f" has been overwritten by `{new_hook_cls}`",
UserWarning,
stacklevel=1,
)

cls._dbt_hooks_by_conn_type[new_hook_cls.conn_type] = new_hook_cls
return new_hook_cls


class DbtConnectionHook(BaseHook, ABC, metaclass=DbtConnectionHookMeta):
"""A hook to get a dbt profile based on the Airflow connection."""

conn_type = "dbt"
hook_name = "dbt Hook"

conn_params: list[Union[DbtConnectionParam, str]] = [
DbtConnectionParam("conn_type", "type"),
"host",
"schema",
"login",
"password",
"port",
]
conn_extra_params: list[Union[DbtConnectionParam, str]] = []

def __init__(
self,
*args,
conn: Connection,
**kwargs,
):
super().__init__(*args, **kwargs)
self.conn = conn

@classmethod
def get_db_conn_hook(cls, conn_id: str) -> DbtConnectionHook:
"""Get a dbt hook class depend on Airflow connection type."""
conn = cls.get_connection(conn_id)

if hook_cls := cls._dbt_hooks_by_conn_type.get(conn.conn_type):
return hook_cls(conn=conn)
raise KeyError(
f"There are no DbtConnectionHook subclasses with conn_type={conn.conn_type}"
)

def get_dbt_target_from_connection(self) -> Optional[dict[str, Any]]:
"""Return a dictionary of connection details to use as a dbt target.
The connection details are fetched from an Airflow connection identified by
self.dbt_conn_id.
Returns:
A dictionary with a configuration for a dbt target, or None if a matching
Airflow connection is not found for given dbt target.
"""
details = self.get_dbt_details_from_connection(self.conn)

return {self.conn.conn_id: details}

def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
@tomasfarias (Owner) commented on Apr 1, 2025:
praise: Okay, this is no longer a classmethod, safe to ignore my comment on this topic. I think this is now fine.

"""Extract dbt connection details from Airflow Connection.
dbt connection details may be present as Airflow Connection attributes or in the
Connection's extras. This class' conn_params and conn_extra_params will be used
to fetch required attributes from attributes and extras respectively. If
conn_extra_params is empty, we merge parameters with all extras.
Subclasses may override this class attributes to narrow down the connection
details for a specific dbt target (like Postgres, or Redshift).
Returns:
A dictionary of dbt connection details.
"""
dbt_details = {}
for param in self.conn_params:
if isinstance(param, DbtConnectionParam):
if not param.depends_on(conn):
continue
key = param.override_name
value = getattr(conn, param.name, param.default)
else:
key = param
value = getattr(conn, key, None)

if value is None:
continue

dbt_details[key] = value

extra = conn.extra_dejson

if not self.conn_extra_params:
return {**dbt_details, **extra}

for param in self.conn_extra_params:
if isinstance(param, DbtConnectionParam):
if not param.depends_on(conn):
continue
key = param.override_name
value = extra.get(param.name, param.default)
else:
key = param
value = extra.get(key, None)

if value is None:
continue

dbt_details[key] = value

return dbt_details
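As a rough sketch of what the method above produces, using the DbtPostgresHook defined further down and an in-memory Connection with invented values (a real hook would normally be built through get_db_conn_hook from a stored connection id):

```python
from airflow.models.connection import Connection

conn = Connection(
    conn_id="my_warehouse",  # made-up connection id
    conn_type="postgres",
    host="localhost",
    schema="analytics",
    login="dbt_user",
    password="s3cr3t",
    port=5432,
    extra='{"connect_timeout": 10}',
)
hook = DbtPostgresHook(conn=conn)
print(hook.get_dbt_target_from_connection())
# Roughly:
# {"my_warehouse": {"type": "postgres", "host": "localhost", "schema": "analytics",
#                   "user": "dbt_user", "password": "s3cr3t", "port": 5432,
#                   "database": "postgres", "connect_timeout": 10}}
```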


class DbtPostgresHook(DbtConnectionHook):
"""A hook to interact with dbt using a Postgres connection."""

conn_type = "postgres"
hook_name = "dbt Postgres Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"host",
DbtConnectionParam("schema", default="public"),
DbtConnectionParam("login", "user"),
"password",
DbtConnectionParam("port", default=5432),
]
conn_extra_params = [
DbtConnectionParam("dbname", "database", "postgres"),
"connect_timeout",
"role",
"search_path",
"keepalives_idle",
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"retries",
]

def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
"""Extract dbt connection details from Airflow Connection.
dbt connection details may be present as Airflow Connection attributes or in the
Connection's extras. This class' conn_params and conn_extra_params will be used
to fetch required attributes from attributes and extras respectively. If
conn_extra_params is empty, we merge parameters with all extras.
Subclasses may override this class attributes to narrow down the connection
details for a specific dbt target (like Postgres, or Redshift).
Returns:
A dictionary of dbt connection details.
"""
if "options" in conn.extra_dejson:
conn = copy(conn)
extra_dejson = conn.extra_dejson
options = extra_dejson.pop("options")
for k, v in re.findall(r"-c (\w+)=(.*)$", options):
extra_dejson[k] = v
conn.extra = json.dumps(extra_dejson)
return super().get_dbt_details_from_connection(conn)
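The options handling above can be pictured with a small standalone check of the regular expression (the values are invented):

```python
import json
import re

extra = {"options": "-c search_path=analytics", "connect_timeout": 10}
options = extra.pop("options")
for key, value in re.findall(r"-c (\w+)=(.*)$", options):
    extra[key] = value
print(json.dumps(extra))
# {"connect_timeout": 10, "search_path": "analytics"}
# Note: the pattern is anchored at the end of the string, so a single -c key=value pair
# is the straightforward case; multiple space-separated pairs would be captured as one value.
```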


class DbtRedshiftHook(DbtPostgresHook):
"""A hook to interact with dbt using a Redshift connection."""

conn_type = "redshift"
hook_name = "dbt Redshift Hook"
conn_extra_params = DbtPostgresHook.conn_extra_params + [
"method",
"cluster_id",
"iam_profile",
"autocreate",
"db_groups",
"ra3_node",
"connect_timeout",
"role",
"region",
]


class DbtSnowflakeHook(DbtConnectionHook):
"""A hook to interact with dbt using a Snowflake connection."""

conn_type = "snowflake"
hook_name = "dbt Snowflake Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"host",
"schema",
DbtConnectionParam(
"login",
"user",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") != "oauth",
),
DbtConnectionParam(
"login",
"oauth_client_id",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
),
DbtConnectionParam(
"password",
depends_on=lambda x: not any(
(
*(
k in x.extra_dejson
for k in ("private_key_file", "private_key_content")
),
x.extra_dejson.get("authenticator", "") == "oauth",
),
),
),
DbtConnectionParam(
"password",
"private_key_passphrase",
depends_on=lambda x: any(
k in x.extra_dejson for k in ("private_key_file", "private_key_content")
),
),
DbtConnectionParam(
"password",
"oauth_client_secret",
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
),
]
conn_extra_params = [
"warehouse",
"role",
"authenticator",
"query_tag",
"client_session_keep_alive",
"connect_timeout",
"retry_on_database_errors",
"retry_all",
"reuse_connections",
"account",
"database",
DbtConnectionParam("refresh_token", "token"),
DbtConnectionParam("private_key_file", "private_key_path"),
DbtConnectionParam("private_key_content", "private_key"),
]
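To sketch how the Snowflake depends_on predicates above resolve, the same Airflow password field lands under a different dbt key depending on the extras (connection values below are invented):

```python
from airflow.models.connection import Connection

# Key-pair authentication: the Airflow password is treated as the private key passphrase.
keypair_conn = Connection(
    conn_id="snowflake_keypair",
    conn_type="snowflake",
    login="dbt_user",
    password="my-passphrase",
    extra='{"account": "my_account", "private_key_file": "/keys/dbt.p8"}',
)
details = DbtSnowflakeHook(conn=keypair_conn).get_dbt_details_from_connection(keypair_conn)
assert details["private_key_passphrase"] == "my-passphrase"
assert details["private_key_path"] == "/keys/dbt.p8"
assert "password" not in details  # the plain `password` mapping is skipped
```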


class DbtBigQueryHook(DbtConnectionHook):
"""A hook to interact with dbt using a BigQuery connection."""

conn_type = "bigquery"
hook_name = "dbt BigQuery Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"schema",
]
conn_extra_params = [
DbtConnectionParam("keyfile_path", "keyfile"),
DbtConnectionParam("keyfile_dict", "keyfile_json"),
"method",
"database",
"schema",
"refresh_token",
"client_id",
"client_secret",
"token_uri",
"OPTIONAL_CONFIG",
]


class DbtSparkHook(DbtConnectionHook):
"""A hook to interact with dbt using a Spark connection."""

conn_type = "spark"
hook_name = "dbt Spark Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", conn_type),
"host",
"port",
"schema",
DbtConnectionParam("login", "user"),
DbtConnectionParam(
"password",
depends_on=lambda x: x.extra_dejson.get("method", "") == "thrift",
),
DbtConnectionParam(
"password",
"token",
depends_on=lambda x: x.extra_dejson.get("method", "") != "thrift",
),
]
conn_extra_params = []
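Putting the pieces together, a minimal usage sketch (my_postgres_conn is a placeholder connection id that would have to exist in the Airflow metadata database):

```python
# The registry built by DbtConnectionHookMeta selects the subclass whose conn_type
# matches the Airflow connection, e.g. DbtPostgresHook for conn_type="postgres".
hook = DbtConnectionHook.get_db_conn_hook("my_postgres_conn")
extra_target = hook.get_dbt_target_from_connection()
# `extra_target` maps the connection id to a dbt target configuration, ready to be
# merged into the profile's "outputs" (see create_dbt_profile in utils/configs.py below).
```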
4 changes: 2 additions & 2 deletions airflow_dbt_python/operators/dbt.py
@@ -99,7 +99,7 @@ def __init__(
send_anonymous_usage_stats: Optional[bool] = None,
no_send_anonymous_usage_stats: Optional[bool] = None,
# Extra features configuration
- dbt_conn_id: Optional[str] = "dbt_conn_id",
+ dbt_conn_id: Optional[str] = None,
profiles_conn_id: Optional[str] = None,
project_conn_id: Optional[str] = None,
do_xcom_push_artifacts: Optional[list[str]] = None,
@@ -305,7 +305,7 @@ def command(self) -> str:
class DbtSeedOperator(DbtBaseOperator):
"""Executes a dbt seed command.

- The seed command loads csv files into the the given target. The
+ The seed command loads csv files into the given target. The
documentation for the dbt command can be found here:
https://docs.getdbt.com/reference/commands/seed.
"""
6 changes: 3 additions & 3 deletions airflow_dbt_python/utils/configs.py
@@ -457,10 +457,10 @@ def create_dbt_profile(
raw_profiles = {}

if extra_targets:
- profile = raw_profiles.setdefault(self.profile_name, {})
- outputs = profile.setdefault("outputs", {})
+ raw_profile = raw_profiles.setdefault(self.profile_name, {})
+ outputs = raw_profile.setdefault("outputs", {})
outputs.setdefault("target", self.target)
- profile["outputs"] = {**outputs, **extra_targets}
+ raw_profile["outputs"] = {**outputs, **extra_targets}

profile = Profile.from_raw_profile_info(
raw_profile=raw_profiles.get(
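For context, the renamed raw_profile avoids clobbering the profile variable assigned a few lines later; the merge itself can be sketched in isolation (dictionary contents invented):

```python
raw_profiles: dict = {}
profile_name = "my_project"  # assumed value of self.profile_name
target = "dev"               # assumed value of self.target
extra_targets = {"my_postgres_conn": {"type": "postgres", "host": "localhost"}}

raw_profile = raw_profiles.setdefault(profile_name, {})
outputs = raw_profile.setdefault("outputs", {})
outputs.setdefault("target", target)
raw_profile["outputs"] = {**outputs, **extra_targets}
# raw_profiles == {"my_project": {"outputs": {"target": "dev",
#                                             "my_postgres_conn": {"type": "postgres",
#                                                                  "host": "localhost"}}}}
```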
8 changes: 2 additions & 6 deletions tests/conftest.py
@@ -217,13 +217,9 @@ def airflow_conns(database):
connections are set for now as our testing database is postgres.
"""
uris = (
f"postgres://{database.user}:{database.password}@{database.host}:{database.port}/public?database={database.dbname}",
f"postgres://{database.user}:{database.password}@{database.host}:{database.port}/public",
)
ids = (
"dbt_test_postgres_1",
database.dbname,
f"postgres://{database.user}:{database.password}@{database.host}:{database.port}/public?dbname={database.dbname}",
)
ids = ("dbt_test_postgres_1",)
session = settings.Session()

connections = []