Skip to content

Commit 637c77b

Browse files
committed
fix: replace the set of DbtConnectionParam with conditions by a single DbtConnectionConditionParam
1 parent 23dfb30 commit 637c77b

File tree

1 file changed

+141
-58
lines changed

1 file changed

+141
-58
lines changed

airflow_dbt_python/hooks/target.py

Lines changed: 141 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import json
6+
import operator
67
import re
78
import warnings
89
from abc import ABC, ABCMeta
@@ -34,7 +35,6 @@ class DbtConnectionParam(NamedTuple):
3435
name: str
3536
store_override_name: Optional[str] = None
3637
default: Optional[Any] = None
37-
depends_on: Callable[[Connection], bool] = lambda x: True
3838

3939
@property
4040
def override_name(self):
@@ -50,6 +50,88 @@ def override_name(self):
5050
return self.store_override_name
5151

5252

53+
class ResolverCondition(NamedTuple):
54+
"""Condition for resolving connection parameters based on extra_dejson.
55+
56+
Attributes:
57+
condition_key: The key in `extra_dejson` to check.
58+
comparison_operator: A function to compare the actual value
59+
with the expected value.
60+
expected: The expected value for the condition to be satisfied.
61+
"""
62+
63+
condition_key: str
64+
comparison_operator: Callable[[Any, Any], bool]
65+
expected: Any
66+
67+
68+
class ResolverResult(NamedTuple):
69+
"""Result of resolving a connection parameter.
70+
71+
Attributes:
72+
override_name: The name to override the parameter with, if applicable.
73+
default: The default value to use if no value is found.
74+
"""
75+
76+
override_name: Optional[str]
77+
default: Optional[Any]
78+
79+
80+
def make_extra_dejson_resolver(
81+
*conditions: tuple[ResolverCondition, ResolverResult],
82+
default: ResolverResult = ResolverResult(None, None),
83+
) -> Callable[[Connection], ResolverResult]:
84+
"""Creates a resolver function for override names and defaults.
85+
86+
Args:
87+
*conditions: A sequence of conditions and their corresponding results.
88+
default: The default result if no condition is met.
89+
90+
Returns:
91+
A function that takes a `Connection` object and returns
92+
the appropriate `ResolverResult`.
93+
"""
94+
95+
def extra_dejson_resolver(conn: Connection) -> ResolverResult:
96+
for (
97+
condition_key,
98+
comparison_operator,
99+
expected,
100+
), resolver_result in conditions:
101+
if comparison_operator(conn.extra_dejson.get(condition_key), expected):
102+
return resolver_result
103+
return default
104+
105+
return extra_dejson_resolver
106+
107+
108+
class DbtConnectionConditionParam(NamedTuple):
109+
"""Connection parameter with dynamic override name and default value.
110+
111+
Attributes:
112+
name: The original name of the parameter.
113+
resolver: A function that resolves the parameter
114+
name and default value based on the connection's `extra_dejson`.
115+
"""
116+
117+
name: str
118+
resolver: Callable[[Connection], ResolverResult]
119+
120+
def resolve(self, connection: Connection) -> ResolverResult:
121+
"""Resolves the override name and default value for this parameter.
122+
123+
Args:
124+
connection: The Airflow connection object.
125+
126+
Returns:
127+
The resolved override name and default value.
128+
"""
129+
override_name, default = self.resolver(connection)
130+
if override_name is None:
131+
return ResolverResult(self.name, default)
132+
return ResolverResult(override_name, default)
133+
134+
53135
class DbtConnectionHookMeta(ABCMeta):
54136
"""A hook metaclass to collect all subclasses of DbtConnectionHook."""
55137

@@ -78,15 +160,15 @@ class DbtConnectionHook(BaseHook, ABC, metaclass=DbtConnectionHookMeta):
78160
hook_name = "dbt Hook"
79161
airflow_conn_types: tuple[str, ...] = ()
80162

81-
conn_params: list[Union[DbtConnectionParam, str]] = [
163+
conn_params: list[Union[DbtConnectionParam, DbtConnectionConditionParam, str]] = [
82164
DbtConnectionParam("conn_type", "type"),
83165
"host",
84166
"schema",
85167
"login",
86168
"password",
87169
"port",
88170
]
89-
conn_extra_params: list[Union[DbtConnectionParam, str]] = []
171+
conn_extra_params: list[Union[DbtConnectionParam, DbtConnectionConditionParam, str]] = []
90172

91173
def __init__(
92174
self,
@@ -139,10 +221,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
139221
dbt_details = {"type": self.conn_type}
140222
for param in self.conn_params:
141223
if isinstance(param, DbtConnectionParam):
142-
if not param.depends_on(conn):
143-
continue
144224
key = param.override_name
145225
value = getattr(conn, param.name, param.default)
226+
elif isinstance(param, DbtConnectionConditionParam):
227+
key, default = param.resolve(conn)
228+
value = getattr(conn, param.name, default)
146229
else:
147230
key = param
148231
value = getattr(conn, key, None)
@@ -159,10 +242,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
159242

160243
for param in self.conn_extra_params:
161244
if isinstance(param, DbtConnectionParam):
162-
if not param.depends_on(conn):
163-
continue
164245
key = param.override_name
165246
value = extra.get(param.name, param.default)
247+
elif isinstance(param, DbtConnectionConditionParam):
248+
key, default = param.resolve(conn)
249+
value = extra.get(param.name, default)
166250
else:
167251
key = param
168252
value = extra.get(key, None)
@@ -220,7 +304,8 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
220304
conn = copy(conn)
221305
extra_dejson = conn.extra_dejson
222306
options = extra_dejson.pop("options")
223-
# This is to pass options (e.g. `-c search_path=myschema`) to dbt in the required form
307+
# This is to pass options (e.g. `-c search_path=myschema`) to dbt
308+
# in the required form
224309
for k, v in re.findall(r"-c (\w+)=(.*)$", options):
225310
extra_dejson[k] = v
226311
conn.extra = json.dumps(extra_dejson)
@@ -235,11 +320,14 @@ class DbtRedshiftHook(DbtPostgresHook):
235320
airflow_conn_types = (conn_type,)
236321

237322
conn_extra_params = DbtPostgresHook.conn_extra_params + [
238-
"method",
239-
DbtConnectionParam(
323+
DbtConnectionConditionParam(
240324
"method",
241-
default="iam",
242-
depends_on=lambda x: x.extra_dejson.get("iam_profile") is not None,
325+
resolver=make_extra_dejson_resolver(
326+
(
327+
ResolverCondition("iam_profile", operator.is_not, None),
328+
ResolverResult(None, "iam"),
329+
)
330+
),
243331
),
244332
"cluster_id",
245333
"iam_profile",
@@ -262,40 +350,33 @@ class DbtSnowflakeHook(DbtConnectionHook):
262350
conn_params = [
263351
"host",
264352
"schema",
265-
DbtConnectionParam(
266-
"login",
267-
"user",
268-
depends_on=lambda x: x.extra_dejson.get("authenticator", "") != "oauth",
269-
),
270-
DbtConnectionParam(
353+
DbtConnectionConditionParam(
271354
"login",
272-
"oauth_client_id",
273-
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
274-
),
275-
DbtConnectionParam(
276-
"password",
277-
depends_on=lambda x: not any(
355+
resolver=make_extra_dejson_resolver(
278356
(
279-
*(
280-
k in x.extra_dejson
281-
for k in ("private_key_file", "private_key_content")
282-
),
283-
x.extra_dejson.get("authenticator", "") == "oauth",
357+
ResolverCondition("authenticator", operator.eq, "oauth"),
358+
ResolverResult("oauth_client_id", None),
284359
),
360+
default=ResolverResult("user", None),
285361
),
286362
),
287-
DbtConnectionParam(
363+
DbtConnectionConditionParam(
288364
"password",
289-
"private_key_passphrase",
290-
depends_on=lambda x: any(
291-
k in x.extra_dejson for k in ("private_key_file", "private_key_content")
365+
resolver=make_extra_dejson_resolver(
366+
(
367+
ResolverCondition("authenticator", operator.eq, "oauth"),
368+
ResolverResult("oauth_client_secret", None),
369+
),
370+
(
371+
ResolverCondition("private_key_file", operator.is_not, None),
372+
ResolverResult("private_key_passphrase", None),
373+
),
374+
(
375+
ResolverCondition("private_key_content", operator.is_not, None),
376+
ResolverResult("private_key_passphrase", None),
377+
),
292378
),
293379
),
294-
DbtConnectionParam(
295-
"password",
296-
"oauth_client_secret",
297-
depends_on=lambda x: x.extra_dejson.get("authenticator", "") == "oauth",
298-
),
299380
]
300381
conn_extra_params = [
301382
"warehouse",
@@ -327,20 +408,22 @@ class DbtBigQueryHook(DbtConnectionHook):
327408
]
328409
conn_extra_params = [
329410
DbtConnectionParam("method", default="oauth"),
330-
DbtConnectionParam(
411+
DbtConnectionConditionParam(
331412
"method",
332-
default="oauth-secrets",
333-
depends_on=lambda x: x.extra_dejson.get("refresh_token") is not None,
334-
),
335-
DbtConnectionParam(
336-
"method",
337-
default="service-account-json",
338-
depends_on=lambda x: x.extra_dejson.get("keyfile_dict") is not None,
339-
),
340-
DbtConnectionParam(
341-
"method",
342-
default="service-account",
343-
depends_on=lambda x: x.extra_dejson.get("key_path") is not None,
413+
resolver=make_extra_dejson_resolver(
414+
(
415+
ResolverCondition("refresh_token", operator.is_not, None),
416+
ResolverResult(None, "oauth-secrets"),
417+
),
418+
(
419+
ResolverCondition("keyfile_dict", operator.is_not, None),
420+
ResolverResult(None, "service-account-json"),
421+
),
422+
(
423+
ResolverCondition("key_path", operator.is_not, None),
424+
ResolverResult(None, "service-account"),
425+
),
426+
),
344427
),
345428
DbtConnectionParam("key_path", "keyfile"),
346429
DbtConnectionParam("keyfile_dict", "keyfile_json"),
@@ -366,14 +449,14 @@ class DbtSparkHook(DbtConnectionHook):
366449
"port",
367450
"schema",
368451
DbtConnectionParam("login", "user"),
369-
DbtConnectionParam(
370-
"password",
371-
depends_on=lambda x: x.extra_dejson.get("method", "") == "thrift",
372-
),
373-
DbtConnectionParam(
452+
DbtConnectionConditionParam(
374453
"password",
375-
"token",
376-
depends_on=lambda x: x.extra_dejson.get("method", "") != "thrift",
454+
resolver=make_extra_dejson_resolver(
455+
(
456+
ResolverCondition("method", operator.ne, "thrift"),
457+
ResolverResult("token", None),
458+
),
459+
),
377460
),
378461
]
379462
conn_extra_params = []

0 commit comments

Comments
 (0)