33from __future__ import annotations
44
55import json
6+ import operator
67import re
78import warnings
89from abc import ABC , ABCMeta
@@ -34,7 +35,6 @@ class DbtConnectionParam(NamedTuple):
3435 name : str
3536 store_override_name : Optional [str ] = None
3637 default : Optional [Any ] = None
37- depends_on : Callable [[Connection ], bool ] = lambda x : True
3838
3939 @property
4040 def override_name (self ):
@@ -50,6 +50,88 @@ def override_name(self):
5050 return self .store_override_name
5151
5252
53+ class ResolverCondition (NamedTuple ):
54+ """Condition for resolving connection parameters based on extra_dejson.
55+
56+ Attributes:
57+ condition_key: The key in `extra_dejson` to check.
58+ comparison_operator: A function to compare the actual value
59+ with the expected value.
60+ expected: The expected value for the condition to be satisfied.
61+ """
62+
63+ condition_key : str
64+ comparison_operator : Callable [[Any , Any ], bool ]
65+ expected : Any
66+
67+
68+ class ResolverResult (NamedTuple ):
69+ """Result of resolving a connection parameter.
70+
71+ Attributes:
72+ override_name: The name to override the parameter with, if applicable.
73+ default: The default value to use if no value is found.
74+ """
75+
76+ override_name : Optional [str ]
77+ default : Optional [Any ]
78+
79+
80+ def make_extra_dejson_resolver (
81+ * conditions : tuple [ResolverCondition , ResolverResult ],
82+ default : ResolverResult = ResolverResult (None , None ),
83+ ) -> Callable [[Connection ], ResolverResult ]:
84+ """Creates a resolver function for override names and defaults.
85+
86+ Args:
87+ *conditions: A sequence of conditions and their corresponding results.
88+ default: The default result if no condition is met.
89+
90+ Returns:
91+ A function that takes a `Connection` object and returns
92+ the appropriate `ResolverResult`.
93+ """
94+
95+ def extra_dejson_resolver (conn : Connection ) -> ResolverResult :
96+ for (
97+ condition_key ,
98+ comparison_operator ,
99+ expected ,
100+ ), resolver_result in conditions :
101+ if comparison_operator (conn .extra_dejson .get (condition_key ), expected ):
102+ return resolver_result
103+ return default
104+
105+ return extra_dejson_resolver
106+
107+
108+ class DbtConnectionConditionParam (NamedTuple ):
109+ """Connection parameter with dynamic override name and default value.
110+
111+ Attributes:
112+ name: The original name of the parameter.
113+ resolver: A function that resolves the parameter
114+ name and default value based on the connection's `extra_dejson`.
115+ """
116+
117+ name : str
118+ resolver : Callable [[Connection ], ResolverResult ]
119+
120+ def resolve (self , connection : Connection ) -> ResolverResult :
121+ """Resolves the override name and default value for this parameter.
122+
123+ Args:
124+ connection: The Airflow connection object.
125+
126+ Returns:
127+ The resolved override name and default value.
128+ """
129+ override_name , default = self .resolver (connection )
130+ if override_name is None :
131+ return ResolverResult (self .name , default )
132+ return ResolverResult (override_name , default )
133+
134+
53135class DbtConnectionHookMeta (ABCMeta ):
54136 """A hook metaclass to collect all subclasses of DbtConnectionHook."""
55137
@@ -78,15 +160,15 @@ class DbtConnectionHook(BaseHook, ABC, metaclass=DbtConnectionHookMeta):
78160 hook_name = "dbt Hook"
79161 airflow_conn_types : tuple [str , ...] = ()
80162
81- conn_params : list [Union [DbtConnectionParam , str ]] = [
163+ conn_params : list [Union [DbtConnectionParam , DbtConnectionConditionParam , str ]] = [
82164 DbtConnectionParam ("conn_type" , "type" ),
83165 "host" ,
84166 "schema" ,
85167 "login" ,
86168 "password" ,
87169 "port" ,
88170 ]
89- conn_extra_params : list [Union [DbtConnectionParam , str ]] = []
171+ conn_extra_params : list [Union [DbtConnectionParam , DbtConnectionConditionParam , str ]] = []
90172
91173 def __init__ (
92174 self ,
@@ -139,10 +221,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
139221 dbt_details = {"type" : self .conn_type }
140222 for param in self .conn_params :
141223 if isinstance (param , DbtConnectionParam ):
142- if not param .depends_on (conn ):
143- continue
144224 key = param .override_name
145225 value = getattr (conn , param .name , param .default )
226+ elif isinstance (param , DbtConnectionConditionParam ):
227+ key , default = param .resolve (conn )
228+ value = getattr (conn , param .name , default )
146229 else :
147230 key = param
148231 value = getattr (conn , key , None )
@@ -159,10 +242,11 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
159242
160243 for param in self .conn_extra_params :
161244 if isinstance (param , DbtConnectionParam ):
162- if not param .depends_on (conn ):
163- continue
164245 key = param .override_name
165246 value = extra .get (param .name , param .default )
247+ elif isinstance (param , DbtConnectionConditionParam ):
248+ key , default = param .resolve (conn )
249+ value = extra .get (param .name , default )
166250 else :
167251 key = param
168252 value = extra .get (key , None )
@@ -220,7 +304,8 @@ def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
220304 conn = copy (conn )
221305 extra_dejson = conn .extra_dejson
222306 options = extra_dejson .pop ("options" )
223- # This is to pass options (e.g. `-c search_path=myschema`) to dbt in the required form
307+ # This is to pass options (e.g. `-c search_path=myschema`) to dbt
308+ # in the required form
224309 for k , v in re .findall (r"-c (\w+)=(.*)$" , options ):
225310 extra_dejson [k ] = v
226311 conn .extra = json .dumps (extra_dejson )
@@ -235,11 +320,14 @@ class DbtRedshiftHook(DbtPostgresHook):
235320 airflow_conn_types = (conn_type ,)
236321
237322 conn_extra_params = DbtPostgresHook .conn_extra_params + [
238- "method" ,
239- DbtConnectionParam (
323+ DbtConnectionConditionParam (
240324 "method" ,
241- default = "iam" ,
242- depends_on = lambda x : x .extra_dejson .get ("iam_profile" ) is not None ,
325+ resolver = make_extra_dejson_resolver (
326+ (
327+ ResolverCondition ("iam_profile" , operator .is_not , None ),
328+ ResolverResult (None , "iam" ),
329+ )
330+ ),
243331 ),
244332 "cluster_id" ,
245333 "iam_profile" ,
@@ -262,40 +350,33 @@ class DbtSnowflakeHook(DbtConnectionHook):
262350 conn_params = [
263351 "host" ,
264352 "schema" ,
265- DbtConnectionParam (
266- "login" ,
267- "user" ,
268- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) != "oauth" ,
269- ),
270- DbtConnectionParam (
353+ DbtConnectionConditionParam (
271354 "login" ,
272- "oauth_client_id" ,
273- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
274- ),
275- DbtConnectionParam (
276- "password" ,
277- depends_on = lambda x : not any (
355+ resolver = make_extra_dejson_resolver (
278356 (
279- * (
280- k in x .extra_dejson
281- for k in ("private_key_file" , "private_key_content" )
282- ),
283- x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
357+ ResolverCondition ("authenticator" , operator .eq , "oauth" ),
358+ ResolverResult ("oauth_client_id" , None ),
284359 ),
360+ default = ResolverResult ("user" , None ),
285361 ),
286362 ),
287- DbtConnectionParam (
363+ DbtConnectionConditionParam (
288364 "password" ,
289- "private_key_passphrase" ,
290- depends_on = lambda x : any (
291- k in x .extra_dejson for k in ("private_key_file" , "private_key_content" )
365+ resolver = make_extra_dejson_resolver (
366+ (
367+ ResolverCondition ("authenticator" , operator .eq , "oauth" ),
368+ ResolverResult ("oauth_client_secret" , None ),
369+ ),
370+ (
371+ ResolverCondition ("private_key_file" , operator .is_not , None ),
372+ ResolverResult ("private_key_passphrase" , None ),
373+ ),
374+ (
375+ ResolverCondition ("private_key_content" , operator .is_not , None ),
376+ ResolverResult ("private_key_passphrase" , None ),
377+ ),
292378 ),
293379 ),
294- DbtConnectionParam (
295- "password" ,
296- "oauth_client_secret" ,
297- depends_on = lambda x : x .extra_dejson .get ("authenticator" , "" ) == "oauth" ,
298- ),
299380 ]
300381 conn_extra_params = [
301382 "warehouse" ,
@@ -327,20 +408,22 @@ class DbtBigQueryHook(DbtConnectionHook):
327408 ]
328409 conn_extra_params = [
329410 DbtConnectionParam ("method" , default = "oauth" ),
330- DbtConnectionParam (
411+ DbtConnectionConditionParam (
331412 "method" ,
332- default = "oauth-secrets" ,
333- depends_on = lambda x : x .extra_dejson .get ("refresh_token" ) is not None ,
334- ),
335- DbtConnectionParam (
336- "method" ,
337- default = "service-account-json" ,
338- depends_on = lambda x : x .extra_dejson .get ("keyfile_dict" ) is not None ,
339- ),
340- DbtConnectionParam (
341- "method" ,
342- default = "service-account" ,
343- depends_on = lambda x : x .extra_dejson .get ("key_path" ) is not None ,
413+ resolver = make_extra_dejson_resolver (
414+ (
415+ ResolverCondition ("refresh_token" , operator .is_not , None ),
416+ ResolverResult (None , "oauth-secrets" ),
417+ ),
418+ (
419+ ResolverCondition ("keyfile_dict" , operator .is_not , None ),
420+ ResolverResult (None , "service-account-json" ),
421+ ),
422+ (
423+ ResolverCondition ("key_path" , operator .is_not , None ),
424+ ResolverResult (None , "service-account" ),
425+ ),
426+ ),
344427 ),
345428 DbtConnectionParam ("key_path" , "keyfile" ),
346429 DbtConnectionParam ("keyfile_dict" , "keyfile_json" ),
@@ -366,14 +449,14 @@ class DbtSparkHook(DbtConnectionHook):
366449 "port" ,
367450 "schema" ,
368451 DbtConnectionParam ("login" , "user" ),
369- DbtConnectionParam (
370- "password" ,
371- depends_on = lambda x : x .extra_dejson .get ("method" , "" ) == "thrift" ,
372- ),
373- DbtConnectionParam (
452+ DbtConnectionConditionParam (
374453 "password" ,
375- "token" ,
376- depends_on = lambda x : x .extra_dejson .get ("method" , "" ) != "thrift" ,
454+ resolver = make_extra_dejson_resolver (
455+ (
456+ ResolverCondition ("method" , operator .ne , "thrift" ),
457+ ResolverResult ("token" , None ),
458+ ),
459+ ),
377460 ),
378461 ]
379462 conn_extra_params = []
0 commit comments