Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .sqlserver.standard_sqlserver_auth import StandardSQLServerAuth
from .teradata.user_pass import TeradataUserPasswordProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand Down Expand Up @@ -52,6 +53,7 @@
TrinoCertificateProfileMapping,
TrinoJWTProfileMapping,
VerticaUserPasswordProfileMapping,
StandardSQLServerAuth,
]


Expand Down Expand Up @@ -96,4 +98,5 @@ def get_automatic_profile_mapping(
"TrinoCertificateProfileMapping",
"TrinoJWTProfileMapping",
"VerticaUserPasswordProfileMapping",
"StandardSQLServerAuth",
]
Empty file.
64 changes: 64 additions & 0 deletions cosmos/profiles/sqlserver/standard_sqlserver_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Maps Airflow Sqlserver connections using user + password authentication to dbt profiles."""

from __future__ import annotations

from typing import Any

from cosmos.profiles import BaseProfileMapping


class StandardSQLServerAuth(BaseProfileMapping):

airflow_connection_type: str = "generic"
dbt_profile_type: str = "sqlserver"
default_port = 1433
is_community = True

required_fields = [
"server",
"user",
"schema",
"database",
"driver",
"password",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"server": "host",
"user": "login",
"password": "password",
"port": "port",
"schema": "schema",
"database": "extra.database",
"driver": "extra.driver",
}

def _set_default_param(self, profile_dict: dict[str, Any]) -> dict[str, Any]:

if not profile_dict.get("port"):
profile_dict["port"] = self.default_port

return profile_dict

@property
def profile(self) -> dict[str, Any]:
profile_dict = {
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(self._set_default_param(profile_dict))

@property
def mock_profile(self) -> dict[str, Any]:
"""Gets mock profile."""

profile_dict = {
**super().mock_profile,
}

return self._set_default_param(profile_dict)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dbt-all = [
"dbt-redshift",
"dbt-snowflake",
"dbt-spark",
"dbt-sqlserver",
"dbt-teradata",
"dbt-vertica",
]
Expand All @@ -71,6 +72,7 @@ dbt-postgres = ["dbt-postgres"]
dbt-redshift = ["dbt-redshift"]
dbt-snowflake = ["dbt-snowflake"]
dbt-spark = ["dbt-spark"]
dbt-sqlserver = ["dbt-sqlserver"]
dbt-teradata = ["dbt-teradata"]
dbt-vertica = ["dbt-vertica<=1.5.4"]
openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"]
Expand Down
117 changes: 117 additions & 0 deletions tests/profiles/sqlserver/test_standard_sqlserver_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for the sqlserver profile."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.sqlserver.standard_sqlserver_auth import (
StandardSQLServerAuth,
)


@pytest.fixture()
def mock_sqlserver_conn(): # type: ignore
"""Sets the connection as an environment variable."""
conn = Connection(
conn_id="sqlserver_connection",
conn_type="generic",
host="my_host",
login="my_user",
port=1433,
password="my_password",
schema="dbo",
extra='{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the sqlserver profile mapping claims the correct connection type.

should only claim when:
- conn_type == generic
And the following exist:
- host
- login
- password
- schema
- extra.database
- extra.driver
"""
required_values = {
"conn_type": "generic",
"host": "my_host",
"login": "my_user",
"schema": "dbo",
"password": "pass",
"extra": '{"database": "my_db", "driver": "ODBC Driver 18 for SQL Server"}',
}

def can_claim_with_missing_key(missing_key: str) -> bool:
values = required_values.copy()
del values[missing_key]
conn = Connection(**values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = StandardSQLServerAuth(conn, {})
return profile_mapping.can_claim_connection()

# if we're missing any of the required values, it shouldn't claim
for key in required_values:
assert not can_claim_with_missing_key(key), f"Failed when missing {key}"

# if we have all the required values, it should claim
conn = Connection(**required_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = StandardSQLServerAuth(conn, {})
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_sqlserver_conn: Connection,
) -> None:
"""Tests that the correct profile mapping is selected."""
profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, {})
assert isinstance(profile_mapping, StandardSQLServerAuth)


def test_profile_args(mock_sqlserver_conn: Connection) -> None:
"""Tests that the profile values get set correctly."""
profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={})

assert profile_mapping.profile == {
"type": "sqlserver",
"schema": mock_sqlserver_conn.schema,
"user": mock_sqlserver_conn.login,
"password": "{{ env_var('COSMOS_CONN_GENERIC_PASSWORD') }}",
"driver": mock_sqlserver_conn.extra_dejson["driver"],
"port": mock_sqlserver_conn.port,
"server": mock_sqlserver_conn.host,
"database": mock_sqlserver_conn.extra_dejson["database"],
}


def test_mock_profile() -> None:
"""Tests that the mock_profile values get set correctly."""
profile_mapping = StandardSQLServerAuth("conn_id")

assert profile_mapping.mock_profile == {
"type": "sqlserver",
"server": "mock_value",
"schema": "mock_value",
"database": "mock_value",
"user": "mock_value",
"password": "mock_value",
"driver": "mock_value",
"port": 1433,
}


def test_profile_env_vars(mock_sqlserver_conn: Connection) -> None:
"""Tests that the environment variables get set correctly."""
profile_mapping = get_automatic_profile_mapping(mock_sqlserver_conn.conn_id, profile_args={})
assert profile_mapping.env_vars == {"COSMOS_CONN_GENERIC_PASSWORD": mock_sqlserver_conn.password}