diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst index 126da03ad630f..7596ad3886c7a 100644 --- a/airflow/providers/amazon/CHANGELOG.rst +++ b/airflow/providers/amazon/CHANGELOG.rst @@ -26,6 +26,33 @@ Changelog --------- +Main +...... + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + In order to support session reuse in RedshiftData operators, the following breaking changes were introduced: + + The ``database`` argument is now optional and as a result was moved after the ``sql`` argument which is a positional + one. Update your DAGs accordingly if they rely on argument order. Applies to: + * ``RedshiftDataHook``'s ``execute_query`` method + * ``RedshiftDataOperator`` + + ``RedshiftDataHook``'s ``execute_query`` method now returns a ``QueryExecutionOutput`` object instead of just the + statement ID as a string. + + ``RedshiftDataHook``'s ``parse_statement_resposne`` method was renamed to ``parse_statement_response``. + + ``S3ToRedshiftOperator``'s ``schema`` argument is now optional and was moved after the ``s3_key`` positional argument. + Update your DAGs accordingly if they rely on argument order. + +Features +~~~~~~~~ + +* ``Support session reuse in RedshiftDataOperator, RedshiftToS3Operator and S3ToRedshiftOperator (#42218)`` + 8.29.0 ...... diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 3c1f84b1f694c..b2f46c0ef6049 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -18,8 +18,12 @@ from __future__ import annotations import time +from dataclasses import dataclass from pprint import pformat from typing import TYPE_CHECKING, Any, Iterable +from uuid import UUID + +from pendulum import duration from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.utils import trim_none_values @@ -35,6 +39,14 @@ RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"} +@dataclass +class QueryExecutionOutput: + """Describes the output of a query execution.""" + + statement_id: str + session_id: str | None + + class RedshiftDataQueryFailedError(ValueError): """Raise an error that redshift data query failed.""" @@ -65,8 +77,8 @@ def __init__(self, *args, **kwargs) -> None: def execute_query( self, - database: str, sql: str | list[str], + database: str | None = None, cluster_identifier: str | None = None, db_user: str | None = None, parameters: Iterable | None = None, @@ -76,23 +88,28 @@ def execute_query( wait_for_completion: bool = True, poll_interval: int = 10, workgroup_name: str | None = None, - ) -> str: + session_id: str | None = None, + session_keep_alive_seconds: int | None = None, + ) -> QueryExecutionOutput: """ Execute a statement against Amazon Redshift. - :param database: the name of the database :param sql: the SQL statement or list of SQL statement to run + :param database: the name of the database :param cluster_identifier: unique identifier of a cluster :param db_user: the database username :param parameters: the parameters for the SQL statement :param secret_arn: the name or ARN of the secret that enables db access :param statement_name: the name of the SQL statement - :param with_event: indicates whether to send an event to EventBridge - :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait + :param with_event: whether to send an event to EventBridge + :param wait_for_completion: whether to wait for a result :param poll_interval: how often in seconds to check the query status :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html + :param session_id: the session identifier of the query + :param session_keep_alive_seconds: duration in seconds to keep the session alive after the query + finishes. The maximum time a session can keep alive is 24 hours :returns statement_id: str, the UUID of the statement """ @@ -105,7 +122,28 @@ def execute_query( "SecretArn": secret_arn, "StatementName": statement_name, "WorkgroupName": workgroup_name, + "SessionId": session_id, + "SessionKeepAliveSeconds": session_keep_alive_seconds, } + + if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1: + raise ValueError( + "Exactly one of cluster_identifier, workgroup_name, or session_id must be provided" + ) + + if session_id is not None: + msg = "session_id must be a valid UUID4" + try: + if UUID(session_id).version != 4: + raise ValueError(msg) + except ValueError: + raise ValueError(msg) + + if session_keep_alive_seconds is not None and ( + session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24 + ): + raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.") + if isinstance(sql, list): kwargs["Sqls"] = sql resp = self.conn.batch_execute_statement(**trim_none_values(kwargs)) @@ -115,13 +153,10 @@ def execute_query( statement_id = resp["Id"] - if bool(cluster_identifier) is bool(workgroup_name): - raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.") - if wait_for_completion: self.wait_for_results(statement_id, poll_interval=poll_interval) - return statement_id + return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId")) def wait_for_results(self, statement_id: str, poll_interval: int) -> str: while True: @@ -135,9 +170,9 @@ def wait_for_results(self, statement_id: str, poll_interval: int) -> str: def check_query_is_finished(self, statement_id: str) -> bool: """Check whether query finished, raise exception is failed.""" resp = self.conn.describe_statement(Id=statement_id) - return self.parse_statement_resposne(resp) + return self.parse_statement_response(resp) - def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool: + def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool: """Parse the response of describe_statement.""" status = resp["Status"] if status == FINISHED_STATE: @@ -179,8 +214,10 @@ def get_table_primary_key( :param table: Name of the target table :param database: the name of the database :param schema: Name of the target schema, public by default - :param sql: the SQL statement or list of SQL statement to run :param cluster_identifier: unique identifier of a cluster + :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with + `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info + https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html :param db_user: the database username :param secret_arn: the name or ARN of the secret that enables db access :param statement_name: the name of the SQL statement @@ -212,7 +249,8 @@ def get_table_primary_key( with_event=with_event, wait_for_completion=wait_for_completion, poll_interval=poll_interval, - ) + ).statement_id + pk_columns = [] token = "" while True: @@ -251,4 +289,4 @@ async def check_query_is_finished_async(self, statement_id: str) -> bool: """ async with self.async_conn as client: resp = await client.describe_statement(Id=statement_id) - return self.parse_statement_resposne(resp) + return self.parse_statement_response(resp) diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index 45fee2a919483..3d00c6d22edf7 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -56,13 +56,16 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]): :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html + :param session_id: the session identifier of the query + :param session_keep_alive_seconds: duration in seconds to keep the session alive after the query + finishes. The maximum time a session can keep alive is 24 hours :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. - :param verify: Whether or not to verify SSL certificates. See: + :param verify: Whether to verify SSL certificates. See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html @@ -77,6 +80,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]): "parameters", "statement_name", "workgroup_name", + "session_id", ) template_ext = (".sql",) template_fields_renderers = {"sql": "sql"} @@ -84,8 +88,8 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]): def __init__( self, - database: str, sql: str | list, + database: str | None = None, cluster_identifier: str | None = None, db_user: str | None = None, parameters: list | None = None, @@ -97,6 +101,8 @@ def __init__( return_sql_result: bool = False, workgroup_name: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + session_id: str | None = None, + session_keep_alive_seconds: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -120,6 +126,8 @@ def __init__( self.return_sql_result = return_sql_result self.statement_id: str | None = None self.deferrable = deferrable + self.session_id = session_id + self.session_keep_alive_seconds = session_keep_alive_seconds def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: """Execute a statement against Amazon Redshift.""" @@ -130,7 +138,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: if self.deferrable: wait_for_completion = False - self.statement_id = self.hook.execute_query( + query_execution_output = self.hook.execute_query( database=self.database, sql=self.sql, cluster_identifier=self.cluster_identifier, @@ -142,8 +150,15 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: with_event=self.with_event, wait_for_completion=wait_for_completion, poll_interval=self.poll_interval, + session_id=self.session_id, + session_keep_alive_seconds=self.session_keep_alive_seconds, ) + self.statement_id = query_execution_output.statement_id + + if query_execution_output.session_id: + self.xcom_push(context, key="session_id", value=query_execution_output.session_id) + if self.deferrable and self.wait_for_completion: is_finished = self.hook.check_query_is_finished(self.statement_id) if not is_finished: diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index ef3cebdae9838..8538b1dfc313c 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -45,7 +45,8 @@ class RedshiftToS3Operator(BaseOperator): :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set to False, this param must include the desired file name :param schema: reference to a specific schema in redshift database, - used when ``table`` param provided and ``select_query`` param not provided + used when ``table`` param provided and ``select_query`` param not provided. + Do not provide when unloading a temporary table :param table: reference to a specific table in redshift database, used when ``schema`` param provided and ``select_query`` param not provided :param select_query: custom select query to fetch data from redshift database, @@ -55,8 +56,8 @@ class RedshiftToS3Operator(BaseOperator): If the AWS connection contains 'aws_iam_role' in ``extras`` the operator will use AWS STS credentials with a token https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. + :param verify: Whether to verify SSL certificates for S3 connection. + By default, SSL certificates are verified. You can provide the following values: - ``False``: do not validate SSL certificates. SSL will still be used @@ -67,7 +68,7 @@ class RedshiftToS3Operator(BaseOperator): CA cert bundle than the one used by botocore. :param unload_options: reference to a list of UNLOAD options :param autocommit: If set to True it will automatically commit the UNLOAD statement. - Otherwise it will be committed right before the redshift connection gets closed. + Otherwise, it will be committed right before the redshift connection gets closed. :param include_header: If set to True the s3 file contains the header columns. :param parameters: (optional) the parameters to render the SQL query with. :param table_as_file_name: If set to True, the s3 file will be named as the table. @@ -141,9 +142,15 @@ def _build_unload_query( @property def default_select_query(self) -> str | None: - if self.schema and self.table: - return f"SELECT * FROM {self.schema}.{self.table}" - return None + if not self.table: + return None + + if self.schema: + table = f"{self.schema}.{self.table}" + else: + # Relevant when unloading a temporary table + table = self.table + return f"SELECT * FROM {table}" def execute(self, context: Context) -> None: if self.table and self.table_as_file_name: @@ -152,9 +159,7 @@ def execute(self, context: Context) -> None: self.select_query = self.select_query or self.default_select_query if self.select_query is None: - raise ValueError( - "Please provide both `schema` and `table` params or `select_query` to fetch the data." - ) + raise ValueError("Please specify either a table or `select_query` to fetch the data.") if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]: self.unload_options = [*self.unload_options, "HEADER"] diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 127ee07a60bbd..792119bfebb55 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -28,7 +28,6 @@ if TYPE_CHECKING: from airflow.utils.context import Context - AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"] @@ -40,17 +39,18 @@ class S3ToRedshiftOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:S3ToRedshiftOperator` - :param schema: reference to a specific schema in redshift database :param table: reference to a specific table in redshift database :param s3_bucket: reference to a specific S3 bucket :param s3_key: key prefix that selects single or multiple objects from S3 + :param schema: reference to a specific schema in redshift database. + Do not provide when copying into a temporary table :param redshift_conn_id: reference to a specific redshift database OR a redshift data-api connection :param aws_conn_id: reference to a specific S3 connection If the AWS connection contains 'aws_iam_role' in ``extras`` the operator will use AWS STS credentials with a token https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. + :param verify: Whether to verify SSL certificates for S3 connection. + By default, SSL certificates are verified. You can provide the following values: - ``False``: do not validate SSL certificates. SSL will still be used @@ -87,10 +87,10 @@ class S3ToRedshiftOperator(BaseOperator): def __init__( self, *, - schema: str, table: str, s3_bucket: str, s3_key: str, + schema: str | None = None, redshift_conn_id: str = "redshift_default", aws_conn_id: str | None = "aws_default", verify: bool | str | None = None, @@ -160,7 +160,7 @@ def execute(self, context: Context) -> None: credentials_block = build_credentials_block(credentials) copy_options = "\n\t\t\t".join(self.copy_options) - destination = f"{self.schema}.{self.table}" + destination = f"{self.schema}.{self.table}" if self.schema else self.table copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination copy_statement = self._build_copy_query( diff --git a/airflow/providers/amazon/aws/utils/openlineage.py b/airflow/providers/amazon/aws/utils/openlineage.py index db472a3e46c5f..be5703e2f6e80 100644 --- a/airflow/providers/amazon/aws/utils/openlineage.py +++ b/airflow/providers/amazon/aws/utils/openlineage.py @@ -86,7 +86,9 @@ def get_facets_from_redshift_table( ] ) else: - statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs) + statement_id = redshift_hook.execute_query( + sql=sql, poll_interval=1, **redshift_data_api_kwargs + ).statement_id response = redshift_hook.conn.get_statement_result(Id=statement_id) table_schema = SchemaDatasetFacet( diff --git a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst index 0b314d34f3193..2638e1732cd6c 100644 --- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst +++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst @@ -54,6 +54,18 @@ the necessity of a Postgres connection. :start-after: [START howto_operator_redshift_data] :end-before: [END howto_operator_redshift_data] +Reuse a session when executing multiple statements +================================================== + +Specify the ``session_keep_alive_seconds`` parameter on an upstream task. In a downstream task, get the session ID from +the XCom and pass it to the ``session_id`` parameter. This is useful when you work with temporary tables. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_redshift_data_session_reuse] + :end-before: [END howto_operator_redshift_data_session_reuse] + Reference --------- diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py index a0952e5ba7259..d548086449812 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_data.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py @@ -19,6 +19,7 @@ import logging from unittest import mock +from uuid import uuid4 import pytest @@ -63,15 +64,18 @@ def test_execute_without_waiting(self, mock_conn): mock_conn.describe_statement.assert_not_called() @pytest.mark.parametrize( - "cluster_identifier, workgroup_name", + "cluster_identifier, workgroup_name, session_id", [ - (None, None), - ("some_cluster", "some_workgroup"), + (None, None, None), + ("some_cluster", "some_workgroup", None), + (None, "some_workgroup", None), + ("some_cluster", None, None), + (None, None, "some_session_id"), ], ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - def test_execute_requires_either_cluster_identifier_or_workgroup_name( - self, mock_conn, cluster_identifier, workgroup_name + def test_execute_requires_one_of_cluster_identifier_or_workgroup_name_or_session_id( + self, mock_conn, cluster_identifier, workgroup_name, session_id ): mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} cluster_identifier = "cluster_identifier" @@ -84,6 +88,51 @@ def test_execute_requires_either_cluster_identifier_or_workgroup_name( workgroup_name=workgroup_name, sql=SQL, wait_for_completion=False, + session_id=session_id, + ) + + @pytest.mark.parametrize( + "cluster_identifier, workgroup_name, session_id", + [ + (None, None, None), + ("some_cluster", "some_workgroup", None), + (None, "some_workgroup", None), + ("some_cluster", None, None), + (None, None, "some_session_id"), + ], + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_session_keep_alive_seconds_valid( + self, mock_conn, cluster_identifier, workgroup_name, session_id + ): + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + cluster_identifier = "cluster_identifier" + workgroup_name = "workgroup_name" + hook = RedshiftDataHook() + with pytest.raises(ValueError): + hook.execute_query( + database=DATABASE, + cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, + sql=SQL, + wait_for_completion=False, + session_id=session_id, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_session_id_valid(self, mock_conn): + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + cluster_identifier = "cluster_identifier" + workgroup_name = "workgroup_name" + hook = RedshiftDataHook() + with pytest.raises(ValueError): + hook.execute_query( + database=DATABASE, + cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, + sql=SQL, + wait_for_completion=False, + session_id="not_a_uuid", ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") @@ -156,6 +205,74 @@ def test_execute_with_all_parameters_workgroup_name(self, mock_conn): Id=STATEMENT_ID, ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_with_new_session(self, mock_conn): + cluster_identifier = "cluster_identifier" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + + hook = RedshiftDataHook() + output = hook.execute_query( + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + session_keep_alive_seconds=123, + ) + assert output.statement_id == STATEMENT_ID + assert output.session_id == "session_id" + + mock_conn.execute_statement.assert_called_once_with( + Database=DATABASE, + Sql=SQL, + ClusterIdentifier=cluster_identifier, + DbUser=db_user, + SecretArn=secret_arn, + StatementName=statement_name, + Parameters=parameters, + WithEvent=False, + SessionKeepAliveSeconds=123, + ) + mock_conn.describe_statement.assert_called_once_with( + Id=STATEMENT_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_reuse_session(self, mock_conn): + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": "session_id"} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + hook = RedshiftDataHook() + session_id = str(uuid4()) + output = hook.execute_query( + database=None, + sql=SQL, + statement_name=statement_name, + parameters=parameters, + session_id=session_id, + ) + assert output.statement_id == STATEMENT_ID + assert output.session_id == "session_id" + + mock_conn.execute_statement.assert_called_once_with( + Sql=SQL, + StatementName=statement_name, + Parameters=parameters, + WithEvent=False, + SessionId=session_id, + ) + mock_conn.describe_statement.assert_called_once_with( + Id=STATEMENT_ID, + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_batch_execute(self, mock_conn): mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index abfa2b038b98b..c22d776a94b44 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -22,6 +22,7 @@ import pytest from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred +from airflow.providers.amazon.aws.hooks.redshift_data import QueryExecutionOutput from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields @@ -31,6 +32,7 @@ SQL = "sql" DATABASE = "database" STATEMENT_ID = "statement_id" +SESSION_ID = "session_id" @pytest.fixture @@ -98,6 +100,8 @@ def test_execute(self, mock_exec_query): poll_interval = 5 wait_for_completion = True + mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None) + operator = RedshiftDataOperator( aws_conn_id=CONN_ID, task_id=TASK_ID, @@ -111,7 +115,8 @@ def test_execute(self, mock_exec_query): wait_for_completion=True, poll_interval=poll_interval, ) - operator.execute(None) + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) mock_exec_query.assert_called_once_with( sql=SQL, database=DATABASE, @@ -124,8 +129,12 @@ def test_execute(self, mock_exec_query): with_event=False, wait_for_completion=wait_for_completion, poll_interval=poll_interval, + session_id=None, + session_keep_alive_seconds=None, ) + mock_ti.xcom_push.assert_not_called() + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") def test_execute_with_workgroup_name(self, mock_exec_query): cluster_identifier = None @@ -150,7 +159,54 @@ def test_execute_with_workgroup_name(self, mock_exec_query): wait_for_completion=True, poll_interval=poll_interval, ) - operator.execute(None) + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) + mock_exec_query.assert_called_once_with( + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + with_event=False, + wait_for_completion=wait_for_completion, + poll_interval=poll_interval, + session_id=None, + session_keep_alive_seconds=None, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_execute_new_session(self, mock_exec_query): + cluster_identifier = "cluster_identifier" + workgroup_name = None + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + wait_for_completion = True + + mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID) + + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + wait_for_completion=True, + poll_interval=poll_interval, + session_keep_alive_seconds=123, + ) + + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) mock_exec_query.assert_called_once_with( sql=SQL, database=DATABASE, @@ -163,7 +219,11 @@ def test_execute_with_workgroup_name(self, mock_exec_query): with_event=False, wait_for_completion=wait_for_completion, poll_interval=poll_interval, + session_id=None, + session_keep_alive_seconds=123, ) + assert mock_ti.xcom_push.call_args.kwargs["key"] == "session_id" + assert mock_ti.xcom_push.call_args.kwargs["value"] == SESSION_ID @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_on_kill_without_query(self, mock_conn): @@ -180,7 +240,7 @@ def test_on_kill_without_query(self, mock_conn): @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_on_kill_with_query(self, mock_conn): - mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID} operator = RedshiftDataOperator( aws_conn_id=CONN_ID, task_id=TASK_ID, @@ -189,7 +249,8 @@ def test_on_kill_with_query(self, mock_conn): database=DATABASE, wait_for_completion=False, ) - operator.execute(None) + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) operator.on_kill() mock_conn.cancel_statement.assert_called_once_with( Id=STATEMENT_ID, @@ -198,7 +259,7 @@ def test_on_kill_with_query(self, mock_conn): @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_return_sql_result(self, mock_conn): expected_result = {"Result": True} - mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID} mock_conn.describe_statement.return_value = {"Status": "FINISHED"} mock_conn.get_statement_result.return_value = expected_result cluster_identifier = "cluster_identifier" @@ -216,7 +277,8 @@ def test_return_sql_result(self, mock_conn): aws_conn_id=CONN_ID, return_sql_result=True, ) - actual_result = operator.execute(None) + mock_ti = mock.MagicMock(name="MockedTaskInstance") + actual_result = operator.execute({"ti": mock_ti}) assert actual_result == expected_result mock_conn.execute_statement.assert_called_once_with( Database=DATABASE, @@ -260,7 +322,9 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin poll_interval=poll_interval, deferrable=True, ) - operator.execute(None) + + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) assert not mock_defer.called mock_exec_query.assert_called_once_with( @@ -275,6 +339,8 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin with_event=False, wait_for_completion=False, poll_interval=poll_interval, + session_id=None, + session_keep_alive_seconds=None, ) @mock.patch( @@ -283,8 +349,9 @@ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_fin ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator): + mock_ti = mock.MagicMock(name="MockedTaskInstance") with pytest.raises(TaskDeferred) as exc: - deferrable_operator.execute(None) + deferrable_operator.execute({"ti": mock_ti}) assert isinstance(exc.value.trigger, RedshiftDataTrigger) @@ -346,7 +413,8 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finis poll_interval=poll_interval, deferrable=deferrable, ) - operator.execute(None) + mock_ti = mock.MagicMock(name="MockedTaskInstance") + operator.execute({"ti": mock_ti}) assert not mock_check_query_is_finished.called assert not mock_defer.called diff --git a/tests/providers/amazon/aws/utils/test_openlineage.py b/tests/providers/amazon/aws/utils/test_openlineage.py index b3e820b58185e..195db068d3092 100644 --- a/tests/providers/amazon/aws/utils/test_openlineage.py +++ b/tests/providers/amazon/aws/utils/test_openlineage.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook +from airflow.providers.amazon.aws.hooks.redshift_data import QueryExecutionOutput, RedshiftDataHook from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook from airflow.providers.amazon.aws.utils.openlineage import ( get_facets_from_redshift_table, @@ -58,7 +58,7 @@ def test_get_facets_from_redshift_table_sql_hook(mock_get_records): @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_get_facets_from_redshift_table_data_hook(mock_connection, mock_execute_query): - mock_execute_query.return_value = "statement_id" + mock_execute_query.return_value = QueryExecutionOutput(statement_id="statement_id", session_id=None) mock_connection.get_statement_result.return_value = { "Records": [ [ diff --git a/tests/system/providers/amazon/aws/example_redshift.py b/tests/system/providers/amazon/aws/example_redshift.py index cc92076dcba0a..67b822d41ef55 100644 --- a/tests/system/providers/amazon/aws/example_redshift.py +++ b/tests/system/providers/amazon/aws/example_redshift.py @@ -50,7 +50,6 @@ DB_NAME = "dev" POLL_INTERVAL = 10 - with DAG( dag_id=DAG_ID, start_date=datetime(2021, 1, 1), @@ -175,6 +174,37 @@ wait_for_completion=True, ) + # [START howto_operator_redshift_data_session_reuse] + create_tmp_table_data_api = RedshiftDataOperator( + task_id="create_tmp_table_data_api", + cluster_identifier=redshift_cluster_identifier, + database=DB_NAME, + db_user=DB_LOGIN, + sql=""" + CREATE TEMPORARY TABLE tmp_people ( + id INTEGER, + first_name VARCHAR(100), + age INTEGER + ); + """, + poll_interval=POLL_INTERVAL, + wait_for_completion=True, + session_keep_alive_seconds=600, + ) + + insert_data_reuse_session = RedshiftDataOperator( + task_id="insert_data_reuse_session", + sql=""" + INSERT INTO tmp_people VALUES ( 1, 'Bob', 30); + INSERT INTO tmp_people VALUES ( 2, 'Alice', 35); + INSERT INTO tmp_people VALUES ( 3, 'Charlie', 40); + """, + poll_interval=POLL_INTERVAL, + wait_for_completion=True, + session_id="{{ task_instance.xcom_pull(task_ids='create_tmp_table_data_api', key='session_id') }}", + ) + # [END howto_operator_redshift_data_session_reuse] + # [START howto_operator_redshift_delete_cluster] delete_cluster = RedshiftDeleteClusterOperator( task_id="delete_cluster", @@ -209,13 +239,20 @@ delete_cluster, ) + # Test session reuse in parallel + chain( + wait_cluster_available_after_resume, + create_tmp_table_data_api, + insert_data_reuse_session, + delete_cluster_snapshot, + ) + from tests.system.utils.watcher import watcher # This test needs watcher in order to properly mark success/failure # when "tearDown" task with trigger rule is part of the DAG list(dag.tasks) >> watcher() - from tests.system.utils import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) diff --git a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py index 0691046190507..9fb989ec53697 100644 --- a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py +++ b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py @@ -53,22 +53,34 @@ S3_KEY = "s3_output_" S3_KEY_2 = "s3_key_2" +S3_KEY_3 = "s3_output_tmp_table_" S3_KEY_PREFIX = "s3_k" REDSHIFT_TABLE = "test_table" +REDSHIFT_TMP_TABLE = "tmp_table" -SQL_CREATE_TABLE = f""" - CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE} ( - fruit_id INTEGER, - name VARCHAR NOT NULL, - color VARCHAR NOT NULL - ); -""" +DATA = "0, 'Airflow', 'testing'" -SQL_INSERT_DATA = f"INSERT INTO {REDSHIFT_TABLE} VALUES ( 1, 'Banana', 'Yellow');" -SQL_DROP_TABLE = f"DROP TABLE IF EXISTS {REDSHIFT_TABLE};" +def _drop_table(table_name: str) -> str: + return f"DROP TABLE IF EXISTS {table_name};" -DATA = "0, 'Airflow', 'testing'" + +def _create_table(table_name: str, is_temp: bool = False) -> str: + temp_keyword = "TEMPORARY" if is_temp else "" + return ( + _drop_table(table_name) + + f""" + CREATE {temp_keyword} TABLE {table_name} ( + fruit_id INTEGER, + name VARCHAR NOT NULL, + color VARCHAR NOT NULL + ); + """ + ) + + +def _insert_data(table_name: str) -> str: + return f"INSERT INTO {table_name} VALUES ( 1, 'Banana', 'Yellow');" with DAG( @@ -124,7 +136,7 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=SQL_CREATE_TABLE, + sql=_create_table(REDSHIFT_TABLE), wait_for_completion=True, ) @@ -133,7 +145,7 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=SQL_INSERT_DATA, + sql=_insert_data(REDSHIFT_TABLE), wait_for_completion=True, ) @@ -159,6 +171,33 @@ bucket_key=f"{S3_KEY}/{REDSHIFT_TABLE}_0000_part_00", ) + create_tmp_table = RedshiftDataOperator( + task_id="create_tmp_table", + cluster_identifier=redshift_cluster_identifier, + database=DB_NAME, + db_user=DB_LOGIN, + sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True) + _insert_data(REDSHIFT_TMP_TABLE), + wait_for_completion=True, + session_keep_alive_seconds=600, + ) + + transfer_redshift_to_s3_reuse_session = RedshiftToS3Operator( + task_id="transfer_redshift_to_s3_reuse_session", + redshift_data_api_kwargs={ + "wait_for_completion": True, + "session_id": "{{ task_instance.xcom_pull(task_ids='create_tmp_table', key='session_id') }}", + }, + s3_bucket=bucket_name, + s3_key=S3_KEY_3, + table=REDSHIFT_TMP_TABLE, + ) + + check_if_tmp_table_key_exists = S3KeySensor( + task_id="check_if_tmp_table_key_exists", + bucket_name=bucket_name, + bucket_key=f"{S3_KEY_3}/{REDSHIFT_TMP_TABLE}_0000_part_00", + ) + # [START howto_transfer_s3_to_redshift] transfer_s3_to_redshift = S3ToRedshiftOperator( task_id="transfer_s3_to_redshift", @@ -176,6 +215,28 @@ ) # [END howto_transfer_s3_to_redshift] + create_dest_tmp_table = RedshiftDataOperator( + task_id="create_dest_tmp_table", + cluster_identifier=redshift_cluster_identifier, + database=DB_NAME, + db_user=DB_LOGIN, + sql=_create_table(REDSHIFT_TMP_TABLE, is_temp=True), + wait_for_completion=True, + session_keep_alive_seconds=600, + ) + + transfer_s3_to_redshift_tmp_table = S3ToRedshiftOperator( + task_id="transfer_s3_to_redshift_tmp_table", + redshift_data_api_kwargs={ + "session_id": "{{ task_instance.xcom_pull(task_ids='create_dest_tmp_table', key='session_id') }}", + "wait_for_completion": True, + }, + s3_bucket=bucket_name, + s3_key=S3_KEY_2, + table=REDSHIFT_TMP_TABLE, + copy_options=["csv"], + ) + # [START howto_transfer_s3_to_redshift_multiple_keys] transfer_s3_to_redshift_multiple = S3ToRedshiftOperator( task_id="transfer_s3_to_redshift_multiple", @@ -198,7 +259,7 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=SQL_DROP_TABLE, + sql=_drop_table(REDSHIFT_TABLE), wait_for_completion=True, trigger_rule=TriggerRule.ALL_DONE, ) @@ -235,13 +296,33 @@ delete_bucket, ) + chain( + # TEST SETUP + wait_cluster_available, + create_tmp_table, + # TEST BODY + transfer_redshift_to_s3_reuse_session, + check_if_tmp_table_key_exists, + # TEST TEARDOWN + delete_cluster, + ) + + chain( + # TEST SETUP + wait_cluster_available, + create_dest_tmp_table, + # TEST BODY + transfer_s3_to_redshift_tmp_table, + # TEST TEARDOWN + delete_cluster, + ) + from tests.system.utils.watcher import watcher # This test needs watcher in order to properly mark success/failure # when "tearDown" task with trigger rule is part of the DAG list(dag.tasks) >> watcher() - from tests.system.utils import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)