-
-
Notifications
You must be signed in to change notification settings - Fork 40
Version 3.0.0 #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tomasfarias
merged 24 commits into
tomasfarias:master
from
millin:fix/adapter_specific_hooks
Apr 12, 2025
Merged
Version 3.0.0 #145
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
fcff9c2
fix: drop py37_copytree
millin 0fc61ec
fix: incorrect conn_id
millin ad97caa
fix: use functools.cache for get_remote
millin c4a87f4
fix: duplicate log lines
millin 1a76ada
fix: poetry deprecation warnings
millin e2700fa
fix: move remote hooks
millin f30c9f0
fix: profiles_dir must be used if specified. close #133
millin ef78d19
fix: add missing config args
millin c0542ef
feat: adapter specific hooks. close #147
millin 20566d9
feat: separate hook to get extra target based on Airflow connection
millin c33df56
fix: fixed conn_type. close #142
millin 4d6edc7
fix: support branches for DbtGitRemoteHook. close #122
millin e415caa
feat: add GCS remote hook. close #139
millin 4ba8726
fix: add missing mutually exclusive attrs
millin 364ba09
docs: update docstring
millin f760553
docs: update
millin 40d1f50
fix: mypy checks
millin c666d7b
chore: Version v3.0.0 bump
millin db29cd7
fix: rename remote to fs
millin 79a369d
fix: add comment for regexp
millin ef3038f
fix: clean up the mess
millin 23dfb30
fix: mock-gcp from PyPI
millin aee9c9c
fix: replace the set of DbtConnectionParam with conditions by a singl…
millin e136027
fix: remove unnecessary os.sep
millin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,351 @@ | ||
| """Provides a hook to get a dbt profile based on the Airflow connection.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import re | ||
| import warnings | ||
| from abc import ABC, ABCMeta | ||
| from copy import copy | ||
| from typing import ( | ||
| Any, | ||
| Callable, | ||
| ClassVar, | ||
| NamedTuple, | ||
| Optional, | ||
| Union, | ||
| ) | ||
|
|
||
| from airflow.hooks.base import BaseHook | ||
| from airflow.models.connection import Connection | ||
|
|
||
|
|
||
class DbtConnectionParam(NamedTuple):
    """Describe a single Airflow connection parameter that is relevant to dbt.

    Attributes:
        name: The name of the connection parameter. It is the key used to read
            the value from an Airflow Connection or from its extras.
        store_override_name: An optional replacement name. When it is not None,
            it is the name used for the parameter in a dbt profile.
        default: The value used if the parameter is not found.
        depends_on: A predicate that receives the Airflow Connection. Hooks
            skip the parameter when it returns a falsy value; the default
            predicate always returns True.
    """

    name: str
    store_override_name: Optional[str] = None
    default: Optional[Any] = None
    depends_on: Callable[[Connection], bool] = lambda _conn: True

    @property
    def override_name(self) -> str:
        """Return the name to use in a dbt profile.

        This is ``store_override_name`` when one was given and ``name``
        otherwise.

        >>> DbtConnectionParam("login", "user").override_name
        'user'
        >>> DbtConnectionParam("port").override_name
        'port'
        """
        overridden = self.store_override_name
        return self.name if overridden is None else overridden
|
|
||
|
|
||
class DbtConnectionHookMeta(ABCMeta):
    """A hook metaclass to collect all subclasses of DbtConnectionHook.

    Every class created through this metaclass is stored in a registry keyed
    by the class' ``conn_type`` attribute, so a hook class can later be looked
    up from an Airflow connection type.
    """

    # Registry shared by all classes created through this metaclass.
    _dbt_hooks_by_conn_type: ClassVar[dict[str, DbtConnectionHookMeta]] = {}
    conn_type: str

    def __new__(cls, name, bases, attrs, **kwargs) -> DbtConnectionHookMeta:
        """Create the class and register it under its conn_type.

        Args:
            name: The name of the class being created.
            bases: The base classes of the class being created.
            attrs: The namespace of the class being created.
            **kwargs: Class keyword arguments; forwarded to ABCMeta so that
                they reach ``__init_subclass__`` instead of being dropped.

        Returns:
            The newly created hook class.
        """
        new_hook_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        conn_type = new_hook_cls.conn_type

        if conn_type in cls._dbt_hooks_by_conn_type:
            warnings.warn(
                f"The connection type `{conn_type}`"
                f" has been overwritten by `{new_hook_cls}`",
                UserWarning,
                # Attribute the warning to the class statement that caused the
                # overwrite rather than to this metaclass.
                stacklevel=2,
            )

        # The most recently defined class wins for a given conn_type.
        cls._dbt_hooks_by_conn_type[conn_type] = new_hook_cls
        return new_hook_cls
|
|
||
|
|
||
class DbtConnectionHook(BaseHook, ABC, metaclass=DbtConnectionHookMeta):
    """A hook to get a dbt profile based on the Airflow connection.

    Subclasses narrow the generic behaviour by overriding ``conn_type`` (the
    Airflow connection type they handle) together with ``conn_params`` and
    ``conn_extra_params``.
    """

    conn_type = "dbt"
    hook_name = "dbt Hook"

    # Connection attributes copied into the dbt target. A plain string keeps
    # its name; a DbtConnectionParam may rename the key, supply a default or
    # make the parameter conditional on the connection.
    conn_params: list[Union[DbtConnectionParam, str]] = [
        DbtConnectionParam("conn_type", "type"),
        "host",
        "schema",
        "login",
        "password",
        "port",
    ]
    # Keys read from the connection extras. When this list is empty, all
    # extras are merged into the dbt target as they are.
    conn_extra_params: list[Union[DbtConnectionParam, str]] = []

    def __init__(
        self,
        *args,
        conn: Connection,
        **kwargs,
    ):
        """Store the Airflow connection this hook builds a dbt target from.

        Args:
            *args: Positional arguments forwarded to the parent initializer.
            conn: The Airflow connection to extract dbt details from.
            **kwargs: Keyword arguments forwarded to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        self.conn = conn

    @classmethod
    def get_db_conn_hook(cls, conn_id: str) -> DbtConnectionHook:
        """Get a dbt hook instance that depends on the Airflow connection type.

        Args:
            conn_id: The id of the Airflow connection, looked up with
                ``get_connection``.

        Returns:
            An instance of the hook class registered for the connection's
            conn_type, wrapping the connection.

        Raises:
            KeyError: If no hook class is registered for the conn_type.
        """
        conn = cls.get_connection(conn_id)

        if hook_cls := cls._dbt_hooks_by_conn_type.get(conn.conn_type):
            return hook_cls(conn=conn)
        raise KeyError(
            f"There are no DbtConnectionHook subclasses with conn_type={conn.conn_type}"
        )

    def get_dbt_target_from_connection(self) -> Optional[dict[str, Any]]:
        """Return a dictionary of connection details to use as a dbt target.

        The connection details are extracted from ``self.conn``.

        Returns:
            A dictionary with a single entry that maps the connection's conn_id
            to its dbt target configuration. This implementation always
            returns a dictionary; the Optional annotation is kept so the
            signature stays unchanged.
        """
        details = self.get_dbt_details_from_connection(self.conn)

        return {self.conn.conn_id: details}

    @staticmethod
    def _collect_params(
        params: list[Union[DbtConnectionParam, str]],
        conn: Connection,
        lookup: Callable[[str, Any], Any],
    ) -> dict[str, Any]:
        """Resolve params into a dict of dbt details using the given lookup.

        Args:
            params: Parameter names or DbtConnectionParam descriptions.
            conn: The connection handed to each ``depends_on`` predicate.
            lookup: Called as ``lookup(name, default)`` to fetch a value.

        Returns:
            The values found, keyed by their dbt name. Parameters whose
            predicate rejects the connection or whose value is None are left
            out.
        """
        resolved: dict[str, Any] = {}
        for param in params:
            if isinstance(param, DbtConnectionParam):
                if not param.depends_on(conn):
                    continue
                key = param.override_name
                value = lookup(param.name, param.default)
            else:
                key = param
                value = lookup(param, None)

            if value is not None:
                resolved[key] = value
        return resolved

    def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
        """Extract dbt connection details from Airflow Connection.

        dbt connection details may be present as Airflow Connection attributes
        or in the Connection's extras. This class' conn_params and
        conn_extra_params are used to fetch the required values from
        attributes and extras respectively. If conn_extra_params is empty, the
        parameters are merged with all extras.

        Subclasses may override these class attributes to narrow down the
        connection details for a specific dbt target (like Postgres, or
        Redshift).

        Args:
            conn: The Airflow connection to read from.

        Returns:
            A dictionary of dbt connection details.
        """
        # NOTE(review): getattr only falls back to the default when the
        # attribute is missing altogether; an attribute that exists with the
        # value None is skipped rather than defaulted. Confirm this is the
        # intended meaning of DbtConnectionParam.default.
        dbt_details = self._collect_params(
            self.conn_params,
            conn,
            lambda name, default: getattr(conn, name, default),
        )

        extra = conn.extra_dejson

        if not self.conn_extra_params:
            # No allow-list of extras: pass all of them through, letting them
            # take precedence over the connection attributes.
            return {**dbt_details, **extra}

        dbt_details.update(
            self._collect_params(self.conn_extra_params, conn, extra.get)
        )
        return dbt_details
|
|
||
|
|
||
class DbtPostgresHook(DbtConnectionHook):
    """A hook to interact with dbt using a Postgres connection."""

    conn_type = "postgres"
    hook_name = "dbt Postgres Hook"
    conn_params = [
        DbtConnectionParam("conn_type", "type", conn_type),
        "host",
        DbtConnectionParam("schema", default="public"),
        DbtConnectionParam("login", "user"),
        "password",
        DbtConnectionParam("port", default=5432),
    ]
    conn_extra_params = [
        DbtConnectionParam("dbname", "database", "postgres"),
        "connect_timeout",
        "role",
        "search_path",
        "keepalives_idle",
        "sslmode",
        "sslcert",
        "sslkey",
        "sslrootcert",
        "retries",
    ]

    # Matches every "-c name=value" setting in an "options" string. The value
    # is matched lazily and stops right before the next "-c" setting (or at the
    # end of the string), so several settings are split apart instead of the
    # first value swallowing everything that follows it.
    _options_pattern = re.compile(r"-c\s+(\w+)=(.*?)(?=\s+-c\s|$)")

    def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
        """Extract dbt connection details from Airflow Connection.

        If the extras contain an "options" entry, every "-c name=value" setting
        found in it is promoted to an extra of its own (for example,
        ``"-c search_path=foo"`` becomes the extra ``search_path``) and the
        "options" entry itself is removed. The details are then extracted by
        the parent implementation from conn_params and conn_extra_params.

        Args:
            conn: The Airflow connection to read from.

        Returns:
            A dictionary of dbt connection details.
        """
        if "options" in conn.extra_dejson:
            # Work on a shallow copy so the rewritten extras are assigned to
            # the copy and not to the connection object that was passed in.
            conn = copy(conn)
            extra_dejson = conn.extra_dejson
            options = extra_dejson.pop("options")
            for name, value in self._options_pattern.findall(options):
                extra_dejson[name] = value
            conn.extra = json.dumps(extra_dejson)
        return super().get_dbt_details_from_connection(conn)
|
|
||
|
|
||
class DbtRedshiftHook(DbtPostgresHook):
    """A hook to interact with dbt using a Redshift connection.

    Connection attributes are handled like in DbtPostgresHook; the extras
    accepted by the Postgres hook are extended with Redshift specific ones.
    """

    conn_type = "redshift"
    hook_name = "dbt Redshift Hook"
    # "connect_timeout" and "role" are already part of the inherited Postgres
    # extras, so they are not listed a second time here.
    conn_extra_params = DbtPostgresHook.conn_extra_params + [
        "method",
        "cluster_id",
        "iam_profile",
        "autocreate",
        "db_groups",
        "ra3_node",
        "region",
    ]
|
|
||
|
|
||
class DbtSnowflakeHook(DbtConnectionHook):
    """A hook to interact with dbt using a Snowflake connection.

    The Airflow login and password are stored under different dbt keys
    depending on what the connection extras contain:

    * login is stored as ``oauth_client_id`` when the "authenticator" extra is
      "oauth", and as ``user`` otherwise.
    * password is stored as ``oauth_client_secret`` when the "authenticator"
      extra is "oauth".
    * password is stored as ``private_key_passphrase`` when the extras contain
      "private_key_file" or "private_key_content".
    * password is stored as ``password`` only when neither of the two cases
      above applies.
    """

    conn_type = "snowflake"
    hook_name = "dbt Snowflake Hook"
    conn_params = [
        DbtConnectionParam(name="conn_type", store_override_name="type", default=conn_type),
        "host",
        "schema",
        # Login: regular user name unless OAuth is in use.
        DbtConnectionParam(
            name="login",
            store_override_name="user",
            depends_on=lambda conn: conn.extra_dejson.get("authenticator") != "oauth",
        ),
        DbtConnectionParam(
            name="login",
            store_override_name="oauth_client_id",
            depends_on=lambda conn: conn.extra_dejson.get("authenticator") == "oauth",
        ),
        # Password: plain password only without a private key and without OAuth.
        DbtConnectionParam(
            name="password",
            depends_on=lambda conn: not (
                "private_key_file" in conn.extra_dejson
                or "private_key_content" in conn.extra_dejson
                or conn.extra_dejson.get("authenticator") == "oauth"
            ),
        ),
        DbtConnectionParam(
            name="password",
            store_override_name="private_key_passphrase",
            depends_on=lambda conn: (
                "private_key_file" in conn.extra_dejson
                or "private_key_content" in conn.extra_dejson
            ),
        ),
        DbtConnectionParam(
            name="password",
            store_override_name="oauth_client_secret",
            depends_on=lambda conn: conn.extra_dejson.get("authenticator") == "oauth",
        ),
    ]
    conn_extra_params = [
        "warehouse",
        "role",
        "authenticator",
        "query_tag",
        "client_session_keep_alive",
        "connect_timeout",
        "retry_on_database_errors",
        "retry_all",
        "reuse_connections",
        "account",
        "database",
        # Extras that are stored under a different key in the dbt target.
        DbtConnectionParam(name="refresh_token", store_override_name="token"),
        DbtConnectionParam(name="private_key_file", store_override_name="private_key_path"),
        DbtConnectionParam(name="private_key_content", store_override_name="private_key"),
    ]
|
|
||
|
|
||
class DbtBigQueryHook(DbtConnectionHook):
    """A hook to interact with dbt using a BigQuery connection.

    Only the connection type and schema are read from the connection
    attributes; everything else comes from the connection extras.
    """

    conn_type = "bigquery"
    hook_name = "dbt BigQuery Hook"
    conn_params = [
        DbtConnectionParam(name="conn_type", store_override_name="type", default=conn_type),
        "schema",
    ]
    conn_extra_params = [
        # Extras that are stored under a different key in the dbt target.
        DbtConnectionParam(name="keyfile_path", store_override_name="keyfile"),
        DbtConnectionParam(name="keyfile_dict", store_override_name="keyfile_json"),
        "method",
        "database",
        # "schema" is also a connection attribute; a value found in the extras
        # replaces the one taken from the attribute.
        "schema",
        "refresh_token",
        "client_id",
        "client_secret",
        "token_uri",
        # NOTE(review): this looks like a documentation placeholder rather than
        # a real extras key - confirm it is intended.
        "OPTIONAL_CONFIG",
    ]
|
|
||
|
|
||
class DbtSparkHook(DbtConnectionHook):
    """A hook to interact with dbt using a Spark connection.

    The Airflow password is stored as ``password`` when the "method" extra is
    "thrift" and as ``token`` for any other method.
    """

    conn_type = "spark"
    hook_name = "dbt Spark Hook"
    conn_params = [
        DbtConnectionParam(name="conn_type", store_override_name="type", default=conn_type),
        "host",
        "port",
        "schema",
        DbtConnectionParam(name="login", store_override_name="user"),
        DbtConnectionParam(
            name="password",
            depends_on=lambda conn: conn.extra_dejson.get("method") == "thrift",
        ),
        DbtConnectionParam(
            name="password",
            store_override_name="token",
            depends_on=lambda conn: conn.extra_dejson.get("method") != "thrift",
        ),
    ]
    # Left empty on purpose: with no explicit list, all connection extras are
    # merged into the dbt details.
    conn_extra_params = []
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
praise: Okay, this is no longer a classmethod, safe to ignore my comment on this topic. I think this is now fine.