Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Any, Iterable
from uuid import UUID

from pendulum import duration

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values
Expand All @@ -35,6 +39,14 @@
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}


@dataclass
class QueryExecutionOutput:
    """
    Describes the output of a query execution.

    Bundles the identifier of the submitted statement with the identifier of
    the session it ran in, when the response carried one.
    """

    # Identifier of the executed statement (the ``Id`` field of the response).
    statement_id: str
    # Identifier of the session the statement ran in; ``None`` when the response
    # did not include a ``SessionId``. Defaults to ``None`` so that callers which
    # do not use sessions can build the output from the statement id alone.
    session_id: str | None = None


class RedshiftDataQueryFailedError(ValueError):
"""Raise an error that redshift data query failed."""

Expand Down Expand Up @@ -65,7 +77,7 @@ def __init__(self, *args, **kwargs) -> None:

def execute_query(
self,
database: str,
database: str | None,
sql: str | list[str],
cluster_identifier: str | None = None,
db_user: str | None = None,
Expand All @@ -76,7 +88,9 @@ def execute_query(
wait_for_completion: bool = True,
poll_interval: int = 10,
workgroup_name: str | None = None,
) -> str:
session_id: str | None = None,
session_keep_alive_seconds: int | None = None,
) -> QueryExecutionOutput:
Comment thread
vincbeck marked this conversation as resolved.
"""
Execute a statement against Amazon Redshift.

Expand All @@ -87,12 +101,15 @@ def execute_query(
:param parameters: the parameters for the SQL statement
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param with_event: whether to send an event to EventBridge
:param wait_for_completion: whether to wait for a result
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param session_id: the session identifier of the query
:param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
finishes. The maximum time a session can keep alive is 24 hours

:returns: a ``QueryExecutionOutput`` holding the statement id and, when the response includes one, the session id
"""
Expand All @@ -105,7 +122,28 @@ def execute_query(
"SecretArn": secret_arn,
"StatementName": statement_name,
"WorkgroupName": workgroup_name,
"SessionId": session_id,
"SessionKeepAliveSeconds": session_keep_alive_seconds,
}

if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
raise ValueError(
"Exactly one of cluster_identifier, workgroup_name, or session_id must be provided"
)

if session_id is not None:
msg = "session_id must be a valid UUID4"
try:
if UUID(session_id).version != 4:
raise ValueError(msg)
except ValueError:
raise ValueError(msg)

if session_keep_alive_seconds is not None and (
session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24
):
raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.")

if isinstance(sql, list):
kwargs["Sqls"] = sql
resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
Expand All @@ -115,13 +153,10 @@ def execute_query(

statement_id = resp["Id"]

if bool(cluster_identifier) is bool(workgroup_name):
raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")

if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)

return statement_id
return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))

def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
Expand All @@ -135,9 +170,9 @@ def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
def check_query_is_finished(self, statement_id: str) -> bool:
    """
    Check whether the query has finished, raising an exception if it failed.

    :param statement_id: the identifier of the statement to look up
    :return: whatever ``parse_statement_response`` returns for the
        ``describe_statement`` response of this statement
    """
    resp = self.conn.describe_statement(Id=statement_id)
    return self.parse_statement_response(resp)

def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool:
def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
"""Parse the response of describe_statement."""
status = resp["Status"]
if status == FINISHED_STATE:
Expand Down Expand Up @@ -179,8 +214,10 @@ def get_table_primary_key(
:param table: Name of the target table
:param database: the name of the database
:param schema: Name of the target schema, public by default
:param sql: the SQL statement or list of SQL statement to run
:param cluster_identifier: unique identifier of a cluster
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param db_user: the database username
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
Expand Down Expand Up @@ -212,7 +249,8 @@ def get_table_primary_key(
with_event=with_event,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
)
).statement_id

pk_columns = []
token = ""
while True:
Expand Down Expand Up @@ -251,4 +289,4 @@ async def check_query_is_finished_async(self, statement_id: str) -> bool:
"""
async with self.async_conn as client:
resp = await client.describe_statement(Id=statement_id)
return self.parse_statement_resposne(resp)
return self.parse_statement_response(resp)
18 changes: 15 additions & 3 deletions airflow/providers/amazon/aws/operators/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
:param verify: Whether to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
Expand All @@ -77,15 +77,16 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
"parameters",
"statement_name",
"workgroup_name",
"session_id",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
statement_id: str | None

def __init__(
self,
database: str,
sql: str | list,
database: str | None = None,
cluster_identifier: str | None = None,
db_user: str | None = None,
parameters: list | None = None,
Expand All @@ -97,6 +98,8 @@ def __init__(
return_sql_result: bool = False,
workgroup_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
session_id: str | None = None,
session_keep_alive_seconds: int | None = None,
Comment thread
vincbeck marked this conversation as resolved.
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -120,6 +123,8 @@ def __init__(
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable
self.session_id = session_id
self.session_keep_alive_seconds = session_keep_alive_seconds

def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
Expand All @@ -130,7 +135,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
if self.deferrable:
wait_for_completion = False

self.statement_id = self.hook.execute_query(
query_execution_output = self.hook.execute_query(
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
Expand All @@ -142,8 +147,15 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
with_event=self.with_event,
wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
session_id=self.session_id,
session_keep_alive_seconds=self.session_keep_alive_seconds,
)

self.statement_id = query_execution_output.statement_id

if query_execution_output.session_id:
self.xcom_push(context, key="session_id", value=query_execution_output.session_id)

if self.deferrable and self.wait_for_completion:
is_finished = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/utils/openlineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def get_facets_from_redshift_table(
]
)
else:
statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs)
statement_id = redshift_hook.execute_query(
sql=sql, poll_interval=1, **redshift_data_api_kwargs
).statement_id
response = redshift_hook.conn.get_statement_result(Id=statement_id)

table_schema = SchemaDatasetFacet(
Expand Down
127 changes: 122 additions & 5 deletions tests/providers/amazon/aws/hooks/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import logging
from unittest import mock
from uuid import uuid4

import pytest

Expand Down Expand Up @@ -63,15 +64,18 @@ def test_execute_without_waiting(self, mock_conn):
mock_conn.describe_statement.assert_not_called()

@pytest.mark.parametrize(
"cluster_identifier, workgroup_name",
"cluster_identifier, workgroup_name, session_id",
[
(None, None),
("some_cluster", "some_workgroup"),
(None, None, None),
("some_cluster", "some_workgroup", None),
(None, "some_workgroup", None),
("some_cluster", None, None),
(None, None, "some_session_id"),
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_requires_either_cluster_identifier_or_workgroup_name(
self, mock_conn, cluster_identifier, workgroup_name
def test_execute_requires_one_of_cluster_identifier_or_workgroup_name_or_session_id(
self, mock_conn, cluster_identifier, workgroup_name, session_id
):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
cluster_identifier = "cluster_identifier"
Expand All @@ -84,6 +88,51 @@ def test_execute_requires_either_cluster_identifier_or_workgroup_name(
workgroup_name=workgroup_name,
sql=SQL,
wait_for_completion=False,
session_id=session_id,
)

@pytest.mark.parametrize("session_keep_alive_seconds", [-1, -86400])
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_session_keep_alive_seconds_valid(self, mock_conn, session_keep_alive_seconds):
    """Reject a negative ``session_keep_alive_seconds`` before any statement is sent."""
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
    hook = RedshiftDataHook()

    # Exactly one target (the cluster) is supplied, so the keep-alive value is the
    # only invalid argument and the error cannot come from the "exactly one of
    # cluster_identifier, workgroup_name, or session_id" check.
    #
    # NOTE(review): values above 86400 are deliberately not parametrized here. The
    # hook's upper-bound check compares pendulum's ``Duration.hours`` with 24, and
    # that property looks like the hours *component* (0-23) rather than the total
    # number of hours -- confirm, fix the hook, then add 86401 as a case.
    with pytest.raises(
        ValueError, match="Session keep alive duration must be between 0 and 86400 seconds"
    ):
        hook.execute_query(
            database=DATABASE,
            cluster_identifier="cluster_identifier",
            sql=SQL,
            wait_for_completion=False,
            session_keep_alive_seconds=session_keep_alive_seconds,
        )

    # Validation happens before the statement is submitted.
    mock_conn.execute_statement.assert_not_called()

@pytest.mark.parametrize(
    "session_id",
    [
        # Not parseable as a UUID at all.
        "not_a_uuid",
        # Well-formed UUID, but version 1 rather than version 4.
        "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
    ],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_session_id_valid(self, mock_conn, session_id):
    """Reject a ``session_id`` that is not a version 4 UUID."""
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
    hook = RedshiftDataHook()

    # Only ``session_id`` is supplied: passing a cluster or workgroup as well would
    # trip the "exactly one of ..." check first and the UUID validation would never
    # be reached. Matching on the message pins the error to the UUID check.
    with pytest.raises(ValueError, match="session_id must be a valid UUID4"):
        hook.execute_query(
            database=None,
            sql=SQL,
            wait_for_completion=False,
            session_id=session_id,
        )

    # Validation happens before the statement is submitted.
    mock_conn.execute_statement.assert_not_called()

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
Expand Down Expand Up @@ -156,6 +205,74 @@ def test_execute_with_all_parameters_workgroup_name(self, mock_conn):
Id=STATEMENT_ID,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_new_session(self, mock_conn):
    """Run a statement against a cluster with a keep-alive and check the session id is returned."""
    mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"}
    query_parameters = [{"name": "id", "value": "1"}]

    result = RedshiftDataHook().execute_query(
        sql=SQL,
        database=DATABASE,
        cluster_identifier="cluster_identifier",
        db_user="db_user",
        secret_arn="secret_arn",
        statement_name="statement_name",
        parameters=query_parameters,
        session_keep_alive_seconds=123,
    )

    # Both identifiers are taken from the mocked response.
    assert result.statement_id == STATEMENT_ID
    assert result.session_id == "session_id"

    # No session id was passed in, so only the keep-alive is forwarded.
    mock_conn.execute_statement.assert_called_once_with(
        Database=DATABASE,
        Sql=SQL,
        ClusterIdentifier="cluster_identifier",
        DbUser="db_user",
        SecretArn="secret_arn",
        StatementName="statement_name",
        Parameters=query_parameters,
        WithEvent=False,
        SessionKeepAliveSeconds=123,
    )
    mock_conn.describe_statement.assert_called_once_with(Id=STATEMENT_ID)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_reuse_session(self, mock_conn):
    """Run a statement with only a session id and check it is forwarded as ``SessionId``."""
    mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"}
    query_parameters = [{"name": "id", "value": "1"}]
    existing_session = str(uuid4())

    result = RedshiftDataHook().execute_query(
        database=None,
        sql=SQL,
        statement_name="statement_name",
        parameters=query_parameters,
        session_id=existing_session,
    )

    # The returned session id comes from the mocked response, not from the request.
    assert result.statement_id == STATEMENT_ID
    assert result.session_id == "session_id"

    # No database, cluster or workgroup is sent along with the session id.
    mock_conn.execute_statement.assert_called_once_with(
        Sql=SQL,
        StatementName="statement_name",
        Parameters=query_parameters,
        WithEvent=False,
        SessionId=existing_session,
    )
    mock_conn.describe_statement.assert_called_once_with(Id=STATEMENT_ID)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
Expand Down
Loading