diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index ecf69e7ba0..c5e6ae14e4 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -24,6 +24,7 @@ from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .sqlserver.standard_sqlserver_auth import StandardSQLServerAuth +from .starrocks import StarrocksUserPasswordProfileMapping from .teradata.user_pass import TeradataUserPasswordProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -47,6 +48,7 @@ SnowflakeEncryptedPrivateKeyFilePemProfileMapping, SnowflakeEncryptedPrivateKeyPemProfileMapping, SnowflakePrivateKeyPemProfileMapping, + StarrocksUserPasswordProfileMapping, SparkThriftProfileMapping, ExasolUserPasswordProfileMapping, TeradataUserPasswordProfileMapping, @@ -93,6 +95,7 @@ def get_automatic_profile_mapping( "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", + "StarrocksUserPasswordProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", "TeradataUserPasswordProfileMapping", diff --git a/cosmos/profiles/starrocks/__init__.py b/cosmos/profiles/starrocks/__init__.py new file mode 100644 index 0000000000..52b6dc6850 --- /dev/null +++ b/cosmos/profiles/starrocks/__init__.py @@ -0,0 +1,5 @@ +"""StarRocks Airflow connection -> dbt profile mappings""" + +from .user_pass import StarrocksUserPasswordProfileMapping + +__all__ = ["StarrocksUserPasswordProfileMapping"] diff --git a/cosmos/profiles/starrocks/user_pass.py b/cosmos/profiles/starrocks/user_pass.py new file mode 100644 index 0000000000..92f24a4a8f --- /dev/null +++ b/cosmos/profiles/starrocks/user_pass.py @@ -0,0 +1,67 @@ +"""Maps Airflow Mysql connections using user + password authentication to dbt profiles.""" + +from __future__ import annotations + +from typing import Any + +from ..base import BaseProfileMapping + + +class StarrocksUserPasswordProfileMapping(BaseProfileMapping): + """ + Maps Airflow MySQL connections using user + password authentication to dbt profiles. + https://docs.getdbt.com/docs/core/connect-data-platform/starrocks-setup + """ + + airflow_connection_type: str = "mysql" # StarRocks support mysql protocol + dbt_profile_type: str = "starrocks" + is_community: bool = True + + required_fields = [ + "host", + "username", + "password", + "port", + "schema", + ] + + secret_fields = [ + "password", + ] + + airflow_param_mapping = { + "host": "host", + "username": "login", + "password": "password", + "port": "port", + "schema": "schema", + } + + @property + def profile(self) -> dict[str, str | int | None]: + """ + Generate the dbt profile configuration for StarRocks. + + Returns: + dict: Profile configuration compatible with dbt-starrocks + """ + profile = { + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + return self.filter_null(profile) + + @property + def mock_profile(self) -> dict[str, Any | None]: + """Gets mock profile.""" + profile_dict = { + **super().mock_profile, + "port": 9030, + } + user_defined_schema = self.profile_args.get("schema") + if user_defined_schema: + profile_dict["schema"] = user_defined_schema + return profile_dict diff --git a/pyproject.toml b/pyproject.toml index 0f832b3fb7..9fdb596d66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dbt-all = [ "dbt-postgres", "dbt-redshift", "dbt-snowflake", + "dbt-starrocks", "dbt-spark", "dbt-sqlserver", "dbt-teradata", @@ -71,6 +72,7 @@ dbt-oracle = ["dbt-oracle"] dbt-postgres = ["dbt-postgres"] dbt-redshift = ["dbt-redshift"] dbt-snowflake = ["dbt-snowflake"] +dbt-starrocks = ["dbt-starrocks"] dbt-spark = ["dbt-spark"] dbt-sqlserver = ["dbt-sqlserver"] dbt-teradata = ["dbt-teradata"] diff --git a/tests/profiles/starrocks/test_starrocks_user_pass.py b/tests/profiles/starrocks/test_starrocks_user_pass.py new file mode 100644 index 0000000000..3489761d33 --- /dev/null +++ b/tests/profiles/starrocks/test_starrocks_user_pass.py @@ -0,0 +1,140 @@ +"""Tests for the starrocks profile mapping.""" + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles.starrocks.user_pass import StarrocksUserPasswordProfileMapping + + +@pytest.fixture() +def mock_starrocks_conn(): # type: ignore + """ + Mocks a StarRocks connection via Airflow MySQL connection (StarRocks speaks MySQL protocol). + """ + conn = Connection( + conn_id="my_starrocks_connection", + conn_type="mysql", + host="my_host", + login="my_user", + password="my_password", + port=9030, + schema="my_database", + ) + + with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_starrocks_conn_custom_port(): # type: ignore + """ + Same as above but with a custom port. + """ + conn = Connection( + conn_id="my_starrocks_connection", + conn_type="mysql", + host="my_host", + login="my_user", + password="my_password", + port=8040, + schema="my_database", + ) + + with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the starrocks profile mapping claims the correct connection type and required fields. + """ + # should only claim when: + # - conn_type == mysql + # - host/login/password/port/schema are present + potential_values = { + "conn_type": "mysql", + "host": "my_host", + "login": "my_user", + "password": "my_password", + "port": 9030, + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": "my_schema"}) + assert not profile_mapping.can_claim_connection() + + # also test when schema is explicitly None in profile_args + conn = Connection(**potential_values) # type: ignore + with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": None}) + assert not profile_mapping.can_claim_connection() + + # if we have them all and provide schema in profile_args, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn): + profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": "my_schema"}) + assert profile_mapping.can_claim_connection() + + +def test_profile_keeps_custom_port(mock_starrocks_conn_custom_port: Connection) -> None: + profile_mapping = StarrocksUserPasswordProfileMapping( + mock_starrocks_conn_custom_port.conn_id, + {"schema": "my_schema"}, + ) + assert profile_mapping.profile["port"] == 8040 + + +def test_profile_args_and_profile(mock_starrocks_conn: Connection) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = StarrocksUserPasswordProfileMapping( + mock_starrocks_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + + assert profile_mapping.profile_args == {"schema": "my_schema"} + + # NOTE: env var name is based on airflow_connection_type (mysql), not dbt_profile_type (starrocks) + assert profile_mapping.profile == { + "type": "starrocks", + "host": mock_starrocks_conn.host, + "username": mock_starrocks_conn.login, + "password": "{{ env_var('COSMOS_CONN_MYSQL_PASSWORD') }}", + "port": mock_starrocks_conn.port, + "schema": "my_schema", + } + + +def test_profile_env_vars(mock_starrocks_conn: Connection) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = StarrocksUserPasswordProfileMapping( + mock_starrocks_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_MYSQL_PASSWORD": mock_starrocks_conn.password, + } + + +def test_mock_profile(mock_starrocks_conn: Connection) -> None: + """ + Tests that the mock profile is generated correctly. + """ + profile_mapping = StarrocksUserPasswordProfileMapping( + mock_starrocks_conn.conn_id, + {"schema": "my_schema"}, + ) + mock_profile = profile_mapping.mock_profile + assert mock_profile["port"] == 9030 + assert mock_profile["schema"] == "my_schema"