From eea21b38bd208057ac775094dc84748080e33c3d Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sat, 3 May 2025 16:00:23 +0530 Subject: [PATCH 1/8] Add sqlserver profile mapping --- cosmos/profiles/__init__.py | 3 + cosmos/profiles/sqlserver/__init__.py | 0 .../sqlserver/standard_sqlserver_auth.py | 60 +++++++++ .../sqlserver/test_standard_sqlserver_auth.py | 118 ++++++++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 cosmos/profiles/sqlserver/__init__.py create mode 100644 cosmos/profiles/sqlserver/standard_sqlserver_auth.py create mode 100644 tests/profiles/sqlserver/test_standard_sqlserver_auth.py diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index d8751899be..c3234bc872 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -22,6 +22,7 @@ from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping +from .sqlserver.standard_sqlserver_auth import StandardSQLServerAuth from .teradata.user_pass import TeradataUserPasswordProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -52,6 +53,7 @@ TrinoCertificateProfileMapping, TrinoJWTProfileMapping, VerticaUserPasswordProfileMapping, + StandardSQLServerAuth, ] @@ -96,4 +98,5 @@ def get_automatic_profile_mapping( "TrinoCertificateProfileMapping", "TrinoJWTProfileMapping", "VerticaUserPasswordProfileMapping", + "StandardSQLServerAuth", ] diff --git a/cosmos/profiles/sqlserver/__init__.py b/cosmos/profiles/sqlserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py new file mode 100644 index 0000000000..e781e90c2b --- /dev/null +++ b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py @@ -0,0 +1,60 @@ +from typing import Any + +from cosmos.profiles import BaseProfileMapping + + +class StandardSQLServerAuth(BaseProfileMapping): + + airflow_connection_type: str = "generic" + dbt_profile_type: str = "sqlserver" + default_port = 1433 + is_community = True + + required_fields = [ + "server", + "user", + "schema", + "database", + "driver", + "password", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "server": "host", + "user": "login", + "password": "password", + "port": "port", + "schema": "schema", + "database": "extra.database", + "driver": "extra.driver", + } + + def _set_default_param(self, profile_dict: dict[str, Any]) -> dict[str, Any]: + + if not profile_dict.get("port"): + profile_dict["port"] = self.default_port + + return profile_dict + + @property + def profile(self) -> dict[str, Any]: + profile_dict = { + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + return self.filter_null(self._set_default_param(profile_dict)) + + @property + def mock_profile(self) -> dict[str, Any]: + """Gets mock profile.""" + + profile_dict = { + **super().mock_profile, + } + + return self._set_default_param(profile_dict) diff --git a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py new file mode 100644 index 0000000000..5035df2e46 --- /dev/null +++ b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py @@ -0,0 +1,118 @@ +"""Tests for the sqlserver profile.""" + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.sqlserver.standard_sqlserver_auth import ( + StandardSQLServerAuth, +) + + +@pytest.fixture() +def mock_sqlserver_conn(): # type: ignore + """Sets the connection as an environment variable.""" + conn = Connection( + conn_id="sqlserver_connection", + conn_type="generic", + host="my_host", + login="my_user", + port=1433, + password="my_password", + schema="dbo", + extra='{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the clickhouse profile mapping claims the correct connection type. + + should only claim when: + - conn_type == generic + And the following exist: + - host + - login + - password + - schema + - extra.databases + - extra.driver + """ + required_values = { + "conn_type": "generic", + "host": "my_host", + "login": "my_user", + "schema": "dbo", + "password": "pass", + "extra": '{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}', + } + + def can_claim_with_missing_key(missing_key: str) -> bool: + values = required_values.copy() + del values[missing_key] + conn = Connection(**values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StandardSQLServerAuth(conn, {}) + return profile_mapping.can_claim_connection() + + # if we're missing any of the required values, it shouldn't claim + for key in required_values: + assert not can_claim_with_missing_key(key), f"Failed when missing {key}" + + # if we have all the required values, it should claim + conn = Connection(**required_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StandardSQLServerAuth(conn, {}) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_sqlserver_conn: Connection, +) -> None: + """Tests that the correct profile mapping is selected.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, {}) + assert isinstance(profile_mapping, StandardSQLServerAuth) + + +def test_profile_args(mock_sqlserver_conn: Connection) -> None: + """Tests that the profile values get set correctly.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={}) + + assert profile_mapping.profile == { + "type": "sqlserver", + "schema": mock_sqlserver_conn.schema, + "user": mock_sqlserver_conn.login, + "password": "{{ env_var('COSMOS_CONN_GENERIC_PASSWORD') }}", + "driver": mock_sqlserver_conn.extra_dejson["driver"], + "port": 1433, + "server": mock_sqlserver_conn.host, + "database": mock_sqlserver_conn.extra_dejson["database"], + } + + +def test_mock_profile() -> None: + """Tests that the mock_profile values get set correctly.""" + profile_mapping = StandardSQLServerAuth( + "conn_id" + ) # get_automatic_profile_mapping("mock_clickhouse_conn.conn_id", profile_args={}) + + assert profile_mapping.mock_profile == { + "type": "sqlserver", + "server": "mock_value", + "schema": "mock_value", + "database": "mock_value", + "user": "mock_value", + "driver": "mock_value", + "port": 1433, + } + + +def test_profile_env_vars(mock_sqlserver_conn: Connection) -> None: + """Tests that the environment variables get set correctly.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={}) + assert profile_mapping.env_vars == {"COSMOS_CONN_GENERIC_PASSWORD": mock_sqlserver_conn.password} From 75afb7611d325c3193c757e1da71f293a4813e7f Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:23:56 +0530 Subject: [PATCH 2/8] Add dbt-sqlserver in pyproject --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9f648ecaa6..23a7809341 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dbt-all = [ "dbt-redshift", "dbt-snowflake", "dbt-spark", + "dbt-sqlserver", "dbt-teradata", "dbt-vertica", ] @@ -71,6 +72,7 @@ dbt-postgres = ["dbt-postgres"] dbt-redshift = ["dbt-redshift"] dbt-snowflake = ["dbt-snowflake"] dbt-spark = ["dbt-spark"] +dbt-sqlserver = ["dbt-sqlserver"] dbt-teradata = ["dbt-teradata"] dbt-vertica = ["dbt-vertica<=1.5.4"] openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"] From 6997cd3f74d44c65b9e4f81e72cf310fdda1a884 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:28:55 +0530 Subject: [PATCH 3/8] Fix tests --- tests/profiles/sqlserver/test_standard_sqlserver_auth.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py index 5035df2e46..9b3fabdade 100644 --- a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py +++ b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py @@ -31,7 +31,7 @@ def mock_sqlserver_conn(): # type: ignore def test_connection_claiming() -> None: """ - Tests that the clickhouse profile mapping claims the correct connection type. + Tests that the sqlserver profile mapping claims the correct connection type. should only claim when: - conn_type == generic @@ -97,9 +97,7 @@ def test_profile_args(mock_sqlserver_conn: Connection) -> None: def test_mock_profile() -> None: """Tests that the mock_profile values get set correctly.""" - profile_mapping = StandardSQLServerAuth( - "conn_id" - ) # get_automatic_profile_mapping("mock_clickhouse_conn.conn_id", profile_args={}) + profile_mapping = StandardSQLServerAuth("conn_id") assert profile_mapping.mock_profile == { "type": "sqlserver", @@ -107,6 +105,7 @@ def test_mock_profile() -> None: "schema": "mock_value", "database": "mock_value", "user": "mock_value", + "password": "mock_value", "driver": "mock_value", "port": 1433, } From 38ef12cab520ed0551a7ba85c5000b229e936379 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:29:33 +0530 Subject: [PATCH 4/8] Fix tests --- tests/profiles/sqlserver/test_standard_sqlserver_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py index 9b3fabdade..1ffc3a3751 100644 --- a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py +++ b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py @@ -89,7 +89,7 @@ def test_profile_args(mock_sqlserver_conn: Connection) -> None: "user": mock_sqlserver_conn.login, "password": "{{ env_var('COSMOS_CONN_GENERIC_PASSWORD') }}", "driver": mock_sqlserver_conn.extra_dejson["driver"], - "port": 1433, + "port": mock_sqlserver_conn.port, "server": mock_sqlserver_conn.host, "database": mock_sqlserver_conn.extra_dejson["database"], } From 17950598039824649ac98e4b81fa8652959d7672 Mon Sep 17 00:00:00 2001 From: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> Date: Sun, 4 May 2025 11:32:34 +0530 Subject: [PATCH 5/8] Update tests/profiles/sqlserver/test_standard_sqlserver_auth.py --- tests/profiles/sqlserver/test_standard_sqlserver_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py index 1ffc3a3751..ba1fbece9e 100644 --- a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py +++ b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py @@ -40,7 +40,7 @@ def test_connection_claiming() -> None: - login - password - schema - - extra.databases + - extra.database - extra.driver """ required_values = { From 8450cb5a3447082f3a6eec1a1ce318a3a0169d06 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:37:33 +0530 Subject: [PATCH 6/8] Fix tests --- cosmos/profiles/sqlserver/standard_sqlserver_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py index e781e90c2b..e25676a865 100644 --- a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py +++ b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict from cosmos.profiles import BaseProfileMapping @@ -31,7 +31,7 @@ class StandardSQLServerAuth(BaseProfileMapping): "driver": "extra.driver", } - def _set_default_param(self, profile_dict: dict[str, Any]) -> dict[str, Any]: + def _set_default_param(self, profile_dict: dict[str, Any]) -> Dict[str, Any]: if not profile_dict.get("port"): profile_dict["port"] = self.default_port From 93858eea628dc236b28957a97d8f1b04b978ec19 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:40:22 +0530 Subject: [PATCH 7/8] Fix tests --- cosmos/profiles/sqlserver/standard_sqlserver_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py index e25676a865..bf134dad9f 100644 --- a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py +++ b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py @@ -31,7 +31,7 @@ class StandardSQLServerAuth(BaseProfileMapping): "driver": "extra.driver", } - def _set_default_param(self, profile_dict: dict[str, Any]) -> Dict[str, Any]: + def _set_default_param(self, profile_dict: Dict[str, Any]) -> Dict[str, Any]: if not profile_dict.get("port"): profile_dict["port"] = self.default_port From 091cfe00bf36d591db45b9a13d804da0ad259e68 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Sun, 4 May 2025 11:48:50 +0530 Subject: [PATCH 8/8] Add future annotations --- cosmos/profiles/sqlserver/standard_sqlserver_auth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py index bf134dad9f..e82902e19c 100644 --- a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py +++ b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py @@ -1,4 +1,8 @@ -from typing import Any, Dict +"""Maps Airflow Sqlserver connections using user + password authentication to dbt profiles.""" + +from __future__ import annotations + +from typing import Any from cosmos.profiles import BaseProfileMapping @@ -31,7 +35,7 @@ class StandardSQLServerAuth(BaseProfileMapping): "driver": "extra.driver", } - def _set_default_param(self, profile_dict: Dict[str, Any]) -> Dict[str, Any]: + def _set_default_param(self, profile_dict: dict[str, Any]) -> dict[str, Any]: if not profile_dict.get("port"): profile_dict["port"] = self.default_port