diff --git a/cosmos/profiles/postgres/user_pass.py b/cosmos/profiles/postgres/user_pass.py index 59188d2675..731d600794 100644 --- a/cosmos/profiles/postgres/user_pass.py +++ b/cosmos/profiles/postgres/user_pass.py @@ -40,8 +40,8 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Gets profile. The password is stored in an environment variable." profile = { - **self.mapped_params, "port": 5432, + **self.mapped_params, **self.profile_args, # password should always get set as env var "password": self.get_env_var_format("password"), @@ -55,6 +55,6 @@ def mock_profile(self) -> dict[str, Any | None]: parent_mock = super().mock_profile return { - **parent_mock, "port": 5432, + **parent_mock, } diff --git a/cosmos/profiles/redshift/user_pass.py b/cosmos/profiles/redshift/user_pass.py index 2a4676b079..0dd3e115f4 100644 --- a/cosmos/profiles/redshift/user_pass.py +++ b/cosmos/profiles/redshift/user_pass.py @@ -41,8 +41,8 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Gets profile." profile = { - **self.mapped_params, "port": 5439, + **self.mapped_params, **self.profile_args, # password should always get set as env var "password": self.get_env_var_format("password"), @@ -56,6 +56,6 @@ def mock_profile(self) -> dict[str, Any | None]: parent_mock = super().mock_profile return { - **parent_mock, "port": 5439, + **parent_mock, } diff --git a/tests/profiles/postgres/test_pg_user_pass.py b/tests/profiles/postgres/test_pg_user_pass.py index db1a457014..3492450b5e 100644 --- a/tests/profiles/postgres/test_pg_user_pass.py +++ b/tests/profiles/postgres/test_pg_user_pass.py @@ -30,6 +30,25 @@ def mock_postgres_conn(): # type: ignore yield conn +@pytest.fixture() +def mock_postgres_conn_custom_port(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_postgres_connection", + conn_type="postgres", + host="my_host", + login="my_user", + password="my_password", + port=7472, + schema="my_database", + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + def test_connection_claiming() -> None: """ Tests that the postgres profile mapping claims the correct connection type. @@ -89,6 +108,11 @@ def test_profile_mapping_selected( assert isinstance(profile_mapping, PostgresUserPasswordProfileMapping) +def test_profile_mapping_keeps_custom_port(mock_postgres_conn_custom_port: Connection) -> None: + profile = PostgresUserPasswordProfileMapping(mock_postgres_conn_custom_port.conn_id, {"schema": "my_schema"}) + assert profile.profile["port"] == 7472 + + def test_profile_args( mock_postgres_conn: Connection, ) -> None: