diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index e9b1f23b6dbbf..340e9577efef4 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -176,8 +176,8 @@ def execute(self, context: 'Context'): params["DBName"] = self.db_name if self.cluster_type: params["ClusterType"] = self.cluster_type - if self.number_of_nodes: - params["NumberOfNodes"] = self.number_of_nodes + if self.cluster_type == "multi-node": + params["NumberOfNodes"] = self.number_of_nodes if self.cluster_security_groups: params["ClusterSecurityGroups"] = self.cluster_security_groups if self.vpc_security_group_ids: diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index e29b397211b95..2ec6897687364 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -41,7 +41,7 @@ def test_init(self): assert redshift_operator.master_user_password == "Test123$" @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") - def test_create_cluster(self, mock_get_conn): + def test_create_single_node_cluster(self, mock_get_conn): redshift_operator = RedshiftCreateClusterOperator( task_id="task_test", cluster_identifier="test-cluster", @@ -54,7 +54,36 @@ def test_create_cluster(self, mock_get_conn): params = { "DBName": "dev", "ClusterType": "single-node", - "NumberOfNodes": 1, + "AutomatedSnapshotRetentionPeriod": 1, + "ClusterVersion": "1.0", + "AllowVersionUpgrade": True, + "PubliclyAccessible": True, + "Port": 5439, + } + mock_get_conn.return_value.create_cluster.assert_called_once_with( + ClusterIdentifier='test-cluster', + NodeType="dc2.large", + MasterUsername="adminuser", + MasterUserPassword="Test123$", + **params, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") + def test_create_multi_node_cluster(self, mock_get_conn): + redshift_operator = RedshiftCreateClusterOperator( + task_id="task_test", + cluster_identifier="test-cluster", + node_type="dc2.large", + number_of_nodes=3, + master_username="adminuser", + master_user_password="Test123$", + cluster_type="multi-node", + ) + redshift_operator.execute(None) + params = { + "DBName": "dev", + "ClusterType": "multi-node", + "NumberOfNodes": 3, "AutomatedSnapshotRetentionPeriod": 1, "ClusterVersion": "1.0", "AllowVersionUpgrade": True,