Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions airflow/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class SqlSensor(BaseSensorOperator):
:type failure: Optional<Callable[[Any], bool]>
:param fail_on_empty: Explicitly fail on no rows returned.
:type fail_on_empty: bool
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:type hook_params: dict
"""

template_fields: Iterable[str] = ('sql',)
Expand All @@ -58,14 +61,24 @@ class SqlSensor(BaseSensorOperator):
ui_color = '#7c7287'

def __init__(
self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False, **kwargs
self,
*,
conn_id,
sql,
parameters=None,
success=None,
failure=None,
fail_on_empty=False,
hook_params=None,
**kwargs,
):
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
self.success = success
self.failure = failure
self.fail_on_empty = fail_on_empty
self.hook_params = hook_params
super().__init__(**kwargs)

def _get_hook(self):
Expand All @@ -90,7 +103,7 @@ def _get_hook(self):
f"Connection type ({conn.conn_type}) is not supported by SqlSensor. "
+ f"Supported connection types: {list(allowed_conn_type)}"
)
return conn.get_hook()
return conn.get_hook(hook_kwargs=self.hook_params)

def poke(self, context):
hook = self._get_hook()
Expand Down
12 changes: 12 additions & 0 deletions tests/sensors/test_sql_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,15 @@ def test_sql_sensor_presto(self):
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

def test_sql_sensor_hook_params(self):
op = SqlSensor(
task_id='sql_sensor_hook_params',
conn_id='google_cloud_default',
sql="SELECT 1",
hook_params={
'delegate_to': 'me',
},
)
hook = op._get_hook()
assert hook.delegate_to == 'me'