-
-
Notifications
You must be signed in to change notification settings - Fork 40
Version 3.0.0 #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Version 3.0.0 #145
Changes from 1 commit
fcff9c2
0fc61ec
ad97caa
c4a87f4
1a76ada
e2700fa
f30c9f0
ef78d19
c0542ef
20566d9
c33df56
4d6edc7
e415caa
4ba8726
364ba09
f760553
40d1f50
c666d7b
db29cd7
79a369d
ef3038f
23dfb30
aee9c9c
e136027
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,13 +4,17 @@ | |
|
|
||
| import json | ||
| import logging | ||
| import os | ||
| import re | ||
| import sys | ||
| from contextlib import contextmanager | ||
| from copy import copy | ||
| from pathlib import Path | ||
| from tempfile import TemporaryDirectory | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Callable, | ||
| Dict, | ||
| Iterable, | ||
| Iterator, | ||
|
|
@@ -69,6 +73,7 @@ class DbtConnectionParam(NamedTuple): | |
| name: str | ||
| store_override_name: Optional[str] = None | ||
| default: Optional[Any] = None | ||
| depends_on: Callable[[Connection], bool] = lambda x: True | ||
|
|
||
| @property | ||
| def override_name(self): | ||
|
|
@@ -119,9 +124,8 @@ class DbtHook(BaseHook): | |
| conn_params: list[Union[DbtConnectionParam, str]] = [ | ||
| DbtConnectionParam("conn_type", "type"), | ||
| "host", | ||
| DbtConnectionParam("conn_id", "dbname"), | ||
| "schema", | ||
| DbtConnectionParam("login", "user"), | ||
| "login", | ||
| "password", | ||
| "port", | ||
| ] | ||
|
|
@@ -240,8 +244,6 @@ def run_dbt_task( | |
| nearest_project_dir = get_nearest_project_dir(config.project_dir) | ||
|
|
||
| with chdir_ctx(nearest_project_dir): | ||
| self.ensure_profiles(config) | ||
|
|
||
| with adapter_management(): | ||
| task, runtime_config = config.create_dbt_task( | ||
| extra_target, write_perf_info | ||
|
|
@@ -412,19 +414,6 @@ def setup_dbt_logging(self, debug: Optional[bool]): | |
| configured_file.setLevel("INFO") | ||
| configured_file.propagate = False | ||
|
|
||
| def ensure_profiles(self, config: BaseConfig): | ||
| """Ensure a profiles file exists.""" | ||
| if config.profiles_dir is not None: | ||
| # We expect one to exist given that we have passed a profiles_dir. | ||
| return | ||
|
|
||
| profiles_path = Path.home() / ".dbt/profiles.yml" | ||
| config.profiles_dir = str(profiles_path.parent) | ||
| if not profiles_path.exists(): | ||
| profiles_path.parent.mkdir(exist_ok=True) | ||
| with profiles_path.open("w", encoding="utf-8") as f: | ||
| f.write("flags:\n send_anonymous_usage_stats: false\n") | ||
|
|
||
| def get_dbt_target_from_connection( | ||
| self, target: Optional[str] | ||
| ) -> Optional[dict[str, Any]]: | ||
|
|
@@ -454,11 +443,14 @@ def get_dbt_target_from_connection( | |
| ) | ||
| return None | ||
|
|
||
| details = self.get_dbt_details_from_connection(conn) | ||
| db_hook_class = self.get_db_hook_class(conn) | ||
|
|
||
| details = db_hook_class.get_dbt_details_from_connection(conn) | ||
|
|
||
| return {conn_id: details} | ||
|
|
||
| def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]: | ||
| @classmethod | ||
| def get_dbt_details_from_connection(cls, conn: Connection) -> dict[str, Any]: | ||
tomasfarias marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Extract dbt connection details from Airflow Connection. | ||
|
|
||
| dbt connection details may be present as Airflow Connection attributes or in the | ||
|
|
@@ -476,8 +468,10 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]: | |
| A dictionary of dbt connection details. | ||
| """ | ||
| dbt_details = {} | ||
| for param in self.conn_params: | ||
| for param in cls.conn_params: | ||
| if isinstance(param, DbtConnectionParam): | ||
| if not param.depends_on(conn): | ||
| continue | ||
| key = param.override_name | ||
| value = getattr(conn, param.name, param.default) | ||
| else: | ||
|
|
@@ -491,11 +485,13 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]: | |
|
|
||
| extra = conn.extra_dejson | ||
|
|
||
| if not self.conn_extra_params: | ||
| if not cls.conn_extra_params: | ||
| return {**dbt_details, **extra} | ||
|
|
||
| for param in self.conn_extra_params: | ||
| for param in cls.conn_extra_params: | ||
| if isinstance(param, DbtConnectionParam): | ||
| if not param.depends_on(conn): | ||
| continue | ||
| key = param.override_name | ||
| value = extra.get(param.name, param.default) | ||
| else: | ||
|
|
@@ -509,43 +505,91 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]: | |
|
|
||
| return dbt_details | ||
|
|
||
| @classmethod | ||
| def get_db_hook_class(cls, conn: Connection) -> type[DbtHook]: | ||
| """Get a dbt hook class depending on the Airflow connection type.""" | ||
| known_dbt_hooks = ( | ||
| DbtPostgresHook, | ||
| DbtRedshiftHook, | ||
| DbtSnowflakeHook, | ||
| DbtBigQueryHook, | ||
| DbtSparkHook, | ||
| ) | ||
| for hook in known_dbt_hooks: | ||
| if hook.conn_type == conn.conn_type: | ||
| return hook | ||
| return DbtHook | ||
|
|
||
|
|
||
| class DbtPostgresHook(DbtHook): | ||
| """A hook to interact with dbt using a Postgres connection.""" | ||
|
|
||
| conn_type = "postgres" | ||
| hook_name = "dbt Postgres Hook" | ||
| conn_params = [ | ||
| DbtConnectionParam("conn_type", "type", "postgres"), | ||
| DbtConnectionParam("conn_type", "type", conn_type), | ||
| "host", | ||
| "schema", | ||
| DbtConnectionParam("schema", default="public"), | ||
| DbtConnectionParam("login", "user"), | ||
| "password", | ||
| "port", | ||
| DbtConnectionParam("port", default=5432), | ||
| ] | ||
| conn_extra_params = [ | ||
| "dbname", | ||
| "threads", | ||
| "keepalives_idle", | ||
| DbtConnectionParam("dbname", "database", "postgres"), | ||
| "connect_timeout", | ||
| "retries", | ||
| "search_path", | ||
| "role", | ||
| "search_path", | ||
| "keepalives_idle", | ||
| "sslmode", | ||
| "sslcert", | ||
| "sslkey", | ||
| "sslrootcert", | ||
| "retries", | ||
| ] | ||
|
|
||
| @classmethod | ||
| def get_dbt_details_from_connection(cls, conn: Connection) -> dict[str, Any]: | ||
| """Extract dbt connection details from Airflow Connection. | ||
|
|
||
| dbt connection details may be present as Airflow Connection attributes or in the | ||
| Connection's extras. This class' conn_params and conn_extra_params will be used | ||
| to fetch required attributes from attributes and extras respectively. If | ||
| conn_extra_params is empty, we merge parameters with all extras. | ||
|
|
||
| Subclasses may override these class attributes to narrow down the connection | ||
| details for a specific dbt target (like Postgres, or Redshift). | ||
|
|
||
| Args: | ||
| conn: The Airflow Connection to extract dbt connection details from. | ||
|
|
||
| Returns: | ||
| A dictionary of dbt connection details. | ||
| """ | ||
| if "options" in conn.extra_dejson: | ||
| conn = copy(conn) | ||
| extra_dejson = conn.extra_dejson | ||
| options = extra_dejson.pop("options") | ||
| for k, v in re.findall(r"-c (\w+)=(.*)$", options): | ||
|
||
| extra_dejson[k] = v | ||
| conn.extra = json.dumps(extra_dejson) | ||
| return super().get_dbt_details_from_connection(conn) | ||
|
|
||
|
|
||
| class DbtRedshiftHook(DbtPostgresHook): | ||
| """A hook to interact with dbt using a Redshift connection.""" | ||
|
|
||
| conn_type = "redshift" | ||
| hook_name = "dbt Redshift Hook" | ||
| conn_extra_params = DbtPostgresHook.conn_extra_params + [ | ||
| "ra3_node", | ||
| "method", | ||
| "cluster_id", | ||
| "iam_profile", | ||
| "iam_duration_secons", | ||
| "autocreate", | ||
| "db_groups", | ||
| "ra3_node", | ||
| "connect_timeout", | ||
| "role", | ||
| "region", | ||
| ] | ||
|
|
||
|
|
||
|
|
@@ -555,22 +599,101 @@ class DbtSnowflakeHook(DbtHook): | |
| conn_type = "snowflake" | ||
| hook_name = "dbt Snowflake Hook" | ||
| conn_params = [ | ||
| DbtConnectionParam("conn_type", "type", "postgres"), | ||
| DbtConnectionParam("conn_type", "type", conn_type), | ||
| "host", | ||
| "schema", | ||
| DbtConnectionParam("login", "user"), | ||
| "password", | ||
| DbtConnectionParam( | ||
| "login", | ||
| "user", | ||
| depends_on=lambda x: x.extra_dejson.get("authenticator", "") != "oauth", | ||
|
||
| ), | ||
| DbtConnectionParam( | ||
| "login", | ||
| "oauth_client_id", | ||
| depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth", | ||
| ), | ||
| DbtConnectionParam( | ||
| "password", | ||
| depends_on=lambda x: not any( | ||
| ( | ||
| *(k in x.extra_dejson for k in ("private_key_file", "private_key_content")), | ||
|
||
| x.extra_dejson.get("authenticator", "") == "oauth", | ||
| ), | ||
| ), | ||
| ), | ||
| DbtConnectionParam( | ||
| "password", | ||
| "private_key_passphrase", | ||
| depends_on=lambda x: any( | ||
| k in x.extra_dejson for k in ("private_key_file", "private_key_content") | ||
tomasfarias marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ), | ||
| ), | ||
| DbtConnectionParam( | ||
| "password", | ||
| "oauth_client_secret", | ||
| depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth", | ||
| ), | ||
| ] | ||
| conn_extra_params = [ | ||
| "account", | ||
| "role", | ||
| "database", | ||
| "warehouse", | ||
| "threads", | ||
| "client_session_keep_alive", | ||
| "role", | ||
| "authenticator", | ||
| "query_tag", | ||
| "connect_retries", | ||
| "client_session_keep_alive", | ||
| "connect_timeout", | ||
| "retry_on_database_errors", | ||
| "retry_all", | ||
| "reuse_connections", | ||
| "account", | ||
| "database", | ||
| DbtConnectionParam("refresh_token", "token"), | ||
| DbtConnectionParam("private_key_file", "private_key_path"), | ||
| DbtConnectionParam("private_key_content", "private_key"), | ||
| ] | ||
|
|
||
|
|
||
| class DbtBigQueryHook(DbtHook): | ||
| """A hook to interact with dbt using a BigQuery connection.""" | ||
|
|
||
| conn_type = "bigquery" | ||
| hook_name = "dbt BigQuery Hook" | ||
| conn_params = [ | ||
| DbtConnectionParam("conn_type", "type", conn_type), | ||
| "schema", | ||
| ] | ||
| conn_extra_params = [ | ||
| DbtConnectionParam("keyfile_path", "keyfile"), | ||
| DbtConnectionParam("keyfile_dict", "keyfile_json"), | ||
| "method", | ||
| "database", | ||
| "schema", | ||
| "refresh_token", | ||
| "client_id", | ||
| "client_secret", | ||
| "token_uri", | ||
| "OPTIONAL_CONFIG", | ||
| ] | ||
|
|
||
|
|
||
| class DbtSparkHook(DbtHook): | ||
| """A hook to interact with dbt using a Spark connection.""" | ||
|
|
||
| conn_type = "spark" | ||
| hook_name = "dbt Spark Hook" | ||
| conn_params = [ | ||
| DbtConnectionParam("conn_type", "type", conn_type), | ||
| "host", | ||
| "port", | ||
| "schema", | ||
| DbtConnectionParam("login", "user"), | ||
| DbtConnectionParam( | ||
| "password", | ||
| depends_on=lambda x: x.extra_dejson.get("method", "") == "thrift", | ||
| ), | ||
| DbtConnectionParam( | ||
| "password", | ||
| "token", | ||
| depends_on=lambda x: x.extra_dejson.get("method", "") != "thrift", | ||
| ), | ||
| ] | ||
| conn_extra_params = [] | ||
Uh oh!
There was an error while loading. Please reload this page.