Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .sqlserver.standard_sqlserver_auth import StandardSQLServerAuth
from .starrocks import StarrocksUserPasswordProfileMapping
from .teradata.user_pass import TeradataUserPasswordProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -47,6 +48,7 @@
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
StarrocksUserPasswordProfileMapping,
SparkThriftProfileMapping,
ExasolUserPasswordProfileMapping,
TeradataUserPasswordProfileMapping,
Expand Down Expand Up @@ -93,6 +95,7 @@ def get_automatic_profile_mapping(
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"StarrocksUserPasswordProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TeradataUserPasswordProfileMapping",
Expand Down
5 changes: 5 additions & 0 deletions cosmos/profiles/starrocks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""StarRocks Airflow connection -> dbt profile mappings"""

from .user_pass import StarrocksUserPasswordProfileMapping

__all__ = ["StarrocksUserPasswordProfileMapping"]
67 changes: 67 additions & 0 deletions cosmos/profiles/starrocks/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Maps Airflow Mysql connections using user + password authentication to dbt profiles."""

from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class StarrocksUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow MySQL connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/docs/core/connect-data-platform/starrocks-setup
"""

airflow_connection_type: str = "mysql" # StarRocks support mysql protocol
dbt_profile_type: str = "starrocks"
is_community: bool = True

required_fields = [
"host",
"username",
"password",
"port",
"schema",
]

secret_fields = [
"password",
]

airflow_param_mapping = {
"host": "host",
"username": "login",
"password": "password",
"port": "port",
"schema": "schema",
}

@property
def profile(self) -> dict[str, str | int | None]:
"""
Generate the dbt profile configuration for StarRocks.

Returns:
dict: Profile configuration compatible with dbt-starrocks
"""
profile = {
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"""Gets mock profile."""
profile_dict = {
**super().mock_profile,
"port": 9030,
}
user_defined_schema = self.profile_args.get("schema")
if user_defined_schema:
profile_dict["schema"] = user_defined_schema
return profile_dict
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dbt-all = [
"dbt-postgres",
"dbt-redshift",
"dbt-snowflake",
"dbt-starrocks",
"dbt-spark",
"dbt-sqlserver",
"dbt-teradata",
Expand All @@ -71,6 +72,7 @@ dbt-oracle = ["dbt-oracle"]
dbt-postgres = ["dbt-postgres"]
dbt-redshift = ["dbt-redshift"]
dbt-snowflake = ["dbt-snowflake"]
dbt-starrocks = ["dbt-starrocks"]
dbt-spark = ["dbt-spark"]
dbt-sqlserver = ["dbt-sqlserver"]
dbt-teradata = ["dbt-teradata"]
Expand Down
140 changes: 140 additions & 0 deletions tests/profiles/starrocks/test_starrocks_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Tests for the starrocks profile mapping."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles.starrocks.user_pass import StarrocksUserPasswordProfileMapping


@pytest.fixture()
def mock_starrocks_conn(): # type: ignore
"""
Mocks a StarRocks connection via Airflow MySQL connection (StarRocks speaks MySQL protocol).
"""
conn = Connection(
conn_id="my_starrocks_connection",
conn_type="mysql",
host="my_host",
login="my_user",
password="my_password",
port=9030,
schema="my_database",
)

with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_starrocks_conn_custom_port(): # type: ignore
"""
Same as above but with a custom port.
"""
conn = Connection(
conn_id="my_starrocks_connection",
conn_type="mysql",
host="my_host",
login="my_user",
password="my_password",
port=8040,
schema="my_database",
)

with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the starrocks profile mapping claims the correct connection type and required fields.
"""
# should only claim when:
# - conn_type == mysql
# - host/login/password/port/schema are present
potential_values = {
"conn_type": "mysql",
"host": "my_host",
"login": "my_user",
"password": "my_password",
"port": 9030,
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn):
profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": "my_schema"})
assert not profile_mapping.can_claim_connection()

# also test when schema is explicitly None in profile_args
conn = Connection(**potential_values) # type: ignore
with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn):
profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": None})
assert not profile_mapping.can_claim_connection()

# if we have them all and provide schema in profile_args, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("cosmos.profiles.base.BaseHook.get_connection", return_value=conn):
profile_mapping = StarrocksUserPasswordProfileMapping(conn, {"schema": "my_schema"})
assert profile_mapping.can_claim_connection()


def test_profile_keeps_custom_port(mock_starrocks_conn_custom_port: Connection) -> None:
profile_mapping = StarrocksUserPasswordProfileMapping(
mock_starrocks_conn_custom_port.conn_id,
{"schema": "my_schema"},
)
assert profile_mapping.profile["port"] == 8040


def test_profile_args_and_profile(mock_starrocks_conn: Connection) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = StarrocksUserPasswordProfileMapping(
mock_starrocks_conn.conn_id,
profile_args={"schema": "my_schema"},
)

assert profile_mapping.profile_args == {"schema": "my_schema"}

# NOTE: env var name is based on airflow_connection_type (mysql), not dbt_profile_type (starrocks)
assert profile_mapping.profile == {
"type": "starrocks",
"host": mock_starrocks_conn.host,
"username": mock_starrocks_conn.login,
"password": "{{ env_var('COSMOS_CONN_MYSQL_PASSWORD') }}",
"port": mock_starrocks_conn.port,
"schema": "my_schema",
}


def test_profile_env_vars(mock_starrocks_conn: Connection) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = StarrocksUserPasswordProfileMapping(
mock_starrocks_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_MYSQL_PASSWORD": mock_starrocks_conn.password,
}


def test_mock_profile(mock_starrocks_conn: Connection) -> None:
"""
Tests that the mock profile is generated correctly.
"""
profile_mapping = StarrocksUserPasswordProfileMapping(
mock_starrocks_conn.conn_id,
{"schema": "my_schema"},
)
mock_profile = profile_mapping.mock_profile
assert mock_profile["port"] == 9030
assert mock_profile["schema"] == "my_schema"