diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index d8751899be..c3234bc872 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -22,6 +22,7 @@ from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping +from .sqlserver.standard_sqlserver_auth import StandardSQLServerAuth from .teradata.user_pass import TeradataUserPasswordProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -52,6 +53,7 @@ TrinoCertificateProfileMapping, TrinoJWTProfileMapping, VerticaUserPasswordProfileMapping, + StandardSQLServerAuth, ] @@ -96,4 +98,5 @@ def get_automatic_profile_mapping( "TrinoCertificateProfileMapping", "TrinoJWTProfileMapping", "VerticaUserPasswordProfileMapping", + "StandardSQLServerAuth", ] diff --git a/cosmos/profiles/sqlserver/__init__.py b/cosmos/profiles/sqlserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cosmos/profiles/sqlserver/standard_sqlserver_auth.py b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py new file mode 100644 index 0000000000..e82902e19c --- /dev/null +++ b/cosmos/profiles/sqlserver/standard_sqlserver_auth.py @@ -0,0 +1,64 @@ +"""Maps Airflow Sqlserver connections using user + password authentication to dbt profiles.""" + +from __future__ import annotations + +from typing import Any + +from cosmos.profiles import BaseProfileMapping + + +class StandardSQLServerAuth(BaseProfileMapping): + + airflow_connection_type: str = "generic" + dbt_profile_type: str = "sqlserver" + default_port = 1433 + is_community = True + + required_fields = [ + "server", + "user", + "schema", + "database", + "driver", + "password", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "server": "host", + "user": "login", + "password": "password", + "port": "port", + "schema": "schema", + "database": "extra.database", + "driver": "extra.driver", + } + + def _set_default_param(self, profile_dict: dict[str, Any]) -> dict[str, Any]: + + if not profile_dict.get("port"): + profile_dict["port"] = self.default_port + + return profile_dict + + @property + def profile(self) -> dict[str, Any]: + profile_dict = { + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + return self.filter_null(self._set_default_param(profile_dict)) + + @property + def mock_profile(self) -> dict[str, Any]: + """Gets mock profile.""" + + profile_dict = { + **super().mock_profile, + } + + return self._set_default_param(profile_dict) diff --git a/pyproject.toml b/pyproject.toml index 9f648ecaa6..23a7809341 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dbt-all = [ "dbt-redshift", "dbt-snowflake", "dbt-spark", + "dbt-sqlserver", "dbt-teradata", "dbt-vertica", ] @@ -71,6 +72,7 @@ dbt-postgres = ["dbt-postgres"] dbt-redshift = ["dbt-redshift"] dbt-snowflake = ["dbt-snowflake"] dbt-spark = ["dbt-spark"] +dbt-sqlserver = ["dbt-sqlserver"] dbt-teradata = ["dbt-teradata"] dbt-vertica = ["dbt-vertica<=1.5.4"] openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"] diff --git a/tests/profiles/sqlserver/test_standard_sqlserver_auth.py b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py new file mode 100644 index 0000000000..ba1fbece9e --- /dev/null +++ b/tests/profiles/sqlserver/test_standard_sqlserver_auth.py @@ -0,0 +1,117 @@ +"""Tests for the sqlserver profile.""" + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.sqlserver.standard_sqlserver_auth import ( + StandardSQLServerAuth, +) + + +@pytest.fixture() +def mock_sqlserver_conn(): # type: ignore + """Sets the connection as an environment variable.""" + conn = Connection( + conn_id="sqlserver_connection", + conn_type="generic", + host="my_host", + login="my_user", + port=1433, + password="my_password", + schema="dbo", + extra='{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the sqlserver profile mapping claims the correct connection type. + + should only claim when: + - conn_type == generic + And the following exist: + - host + - login + - password + - schema + - extra.database + - extra.driver + """ + required_values = { + "conn_type": "generic", + "host": "my_host", + "login": "my_user", + "schema": "dbo", + "password": "pass", + "extra": '{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}', + } + + def can_claim_with_missing_key(missing_key: str) -> bool: + values = required_values.copy() + del values[missing_key] + conn = Connection(**values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StandardSQLServerAuth(conn, {}) + return profile_mapping.can_claim_connection() + + # if we're missing any of the required values, it shouldn't claim + for key in required_values: + assert not can_claim_with_missing_key(key), f"Failed when missing {key}" + + # if we have all the required values, it should claim + conn = Connection(**required_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StandardSQLServerAuth(conn, {}) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_sqlserver_conn: Connection, +) -> None: + """Tests that the correct profile mapping is selected.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, {}) + assert isinstance(profile_mapping, StandardSQLServerAuth) + + +def test_profile_args(mock_sqlserver_conn: Connection) -> None: + """Tests that the profile values get set correctly.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={}) + + assert profile_mapping.profile == { + "type": "sqlserver", + "schema": mock_sqlserver_conn.schema, + "user": mock_sqlserver_conn.login, + "password": "{{ env_var('COSMOS_CONN_GENERIC_PASSWORD') }}", + "driver": mock_sqlserver_conn.extra_dejson["driver"], + "port": mock_sqlserver_conn.port, + "server": mock_sqlserver_conn.host, + "database": mock_sqlserver_conn.extra_dejson["database"], + } + + +def test_mock_profile() -> None: + """Tests that the mock_profile values get set correctly.""" + profile_mapping = StandardSQLServerAuth("conn_id") + + assert profile_mapping.mock_profile == { + "type": "sqlserver", + "server": "mock_value", + "schema": "mock_value", + "database": "mock_value", + "user": "mock_value", + "password": "mock_value", + "driver": "mock_value", + "port": 1433, + } + + +def test_profile_env_vars(mock_sqlserver_conn: Connection) -> None: + """Tests that the environment variables get set correctly.""" + profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={}) + assert profile_mapping.env_vars == {"COSMOS_CONN_GENERIC_PASSWORD": mock_sqlserver_conn.password}