diff --git a/providers/dbt/core/operators.py b/providers/dbt/core/operators.py index 1bfa459002..c41453f60b 100644 --- a/providers/dbt/core/operators.py +++ b/providers/dbt/core/operators.py @@ -2,17 +2,16 @@ import os import shutil -import sys -import yaml from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.hooks.base import BaseHook from airflow.hooks.subprocess import SubprocessHook from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context from airflow.utils.operator_helpers import context_to_airflow_vars +from cosmos.providers.dbt.core.utils.profiles_generator import create_default_profiles, map_profile + class DBTBaseOperator(BaseOperator): @@ -149,63 +148,9 @@ def run_command(self, cmd, env): self.exception_handling(result) return result - def create_default_profiles(self): - profile = { - "postgres_profile": { - "outputs": { - "dev": { - "type": "postgres", - "host": "{{ env_var('POSTGRES_HOST') }}", - "port": "{{ env_var('POSTGRES_PORT') | as_number }}", - "user": "{{ env_var('POSTGRES_USER') }}", - "pass": "{{ env_var('POSTGRES_PASSWORD') }}", - "dbname": "{{ env_var('POSTGRES_DATABASE') }}", - "schema": "{{ env_var('POSTGRES_SCHEMA') }}", - } - }, - "target": "dev", - } - } - # Define the path to the directory and file - directory_path = "/home/astro/.dbt" - file_path = "/home/astro/.dbt/profiles.yml" - - # Create the directory if it does not exist - if not os.path.exists(directory_path): - os.makedirs(directory_path) - - # Create the file if it does not exist - if not os.path.exists(file_path): - print("profiles.yml not found - initializing.") - with open(file_path, "w") as file: - yaml.dump(profile, file) - print("done") - else: - print("profiles.yml found - skipping") - - def map_profile(self): - conn = BaseHook().get_connection(self.conn_id) - - if conn.conn_type == "postgres": - profile = "postgres_profile" - profile_vars = { - "POSTGRES_HOST": conn.host, - "POSTGRES_USER": conn.login, - "POSTGRES_PASSWORD": conn.password, - "POSTGRES_DATABASE": conn.schema, - "POSTGRES_PORT": str(conn.port), - "POSTGRES_SCHEMA": self.schema, - } - - else: - print(f"Connection type {conn.type} is not yet supported.", file=sys.stderr) - sys.exit(1) - - return profile, profile_vars - - def build_and_run_cmd(self, env): - self.create_default_profiles() - profile, profile_vars = self.map_profile() + def build_and_run_cmd(self, env): + create_default_profiles() + profile, profile_vars = map_profile(self.conn_id, self.schema) env = env | profile_vars cmd = self.build_command() + ["--profile", profile] result = self.run_command(cmd=cmd, env=env) diff --git a/providers/dbt/core/profiles/postgres.py b/providers/dbt/core/profiles/postgres.py new file mode 100644 index 0000000000..75cb461d28 --- /dev/null +++ b/providers/dbt/core/profiles/postgres.py @@ -0,0 +1,14 @@ +postgres_profile = { + "outputs": { + "dev": { + "type": "postgres", + "host": "{{ env_var('POSTGRES_HOST') }}", + "port": "{{ env_var('POSTGRES_PORT') | as_number }}", + "user": "{{ env_var('POSTGRES_USER') }}", + "pass": "{{ env_var('POSTGRES_PASSWORD') }}", + "dbname": "{{ env_var('POSTGRES_DATABASE') }}", + "schema": "{{ env_var('POSTGRES_SCHEMA') }}", + } + }, + "target": "dev", +} diff --git a/providers/dbt/core/profiles/snowflake.py b/providers/dbt/core/profiles/snowflake.py new file mode 100644 index 0000000000..e8ac0e01f5 --- /dev/null +++ b/providers/dbt/core/profiles/snowflake.py @@ -0,0 +1,16 @@ +snowflake_profile = { + "target": "dev", + "outputs": { + "dev": { + "type": "snowflake", + "account": "{{ env_var('SNOWFLAKE_ACCOUNT') }}", + "user": "{{ env_var('SNOWFLAKE_USER') }}", + "password": "{{ env_var('SNOWFLAKE_PASSWORD') }}", + "role": "{{ env_var('SNOWFLAKE_ROLE') }}", + "database": "{{ env_var('SNOWFLAKE_DATABASE') }}", + "warehouse": "{{ env_var('SNOWFLAKE_WAREHOUSE') }}", + "schema": "{{ env_var('SNOWFLAKE_SCHEMA') }}", + "client_session_keep_alive": False, + } + }, +} diff --git a/providers/dbt/core/utils/profiles_generator.py b/providers/dbt/core/utils/profiles_generator.py new file mode 100644 index 0000000000..eb33c18ea2 --- /dev/null +++ b/providers/dbt/core/utils/profiles_generator.py @@ -0,0 +1,66 @@ +import os +import sys + +import yaml +from airflow.hooks.base import BaseHook +from airflow.models.connection import Connection + +from cosmos.providers.dbt.core.profiles.postgres import postgres_profile +from cosmos.providers.dbt.core.profiles.snowflake import snowflake_profile + + +def create_default_profiles(): + profiles = {"postgres_profile": postgres_profile, "snowflake_profile": snowflake_profile} + # Define the path to the directory and file + directory_path = "/home/astro/.dbt" + file_path = "/home/astro/.dbt/profiles.yml" + + # Create the directory if it does not exist + if not os.path.exists(directory_path): + os.makedirs(directory_path) + + # Create the file if it does not exist + if not os.path.exists(file_path): + print("profiles.yml not found - initializing.") + with open(file_path, "w") as file: + yaml.dump(profiles, file) + print("done") + else: + print("profiles.yml found - skipping") + + +def create_profile_vars(conn: Connection, schema_override): + if conn.conn_type == "postgres": + profile = "postgres_profile" + profile_vars = { + "POSTGRES_HOST": conn.host, + "POSTGRES_USER": conn.login, + "POSTGRES_PASSWORD": conn.password, + "POSTGRES_DATABASE": conn.schema, + "POSTGRES_PORT": str(conn.port), + "POSTGRES_SCHEMA": schema_override, + } + + elif conn.conn_type == "snowflake": + profile = "snowflake_profile" + profile_vars = { + "SNOWFLAKE_USER": conn.login, + "SNOWFLAKE_PASSWORD": conn.password, + "SNOWFLAKE_ACCOUNT": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}", + "SNOWFLAKE_ROLE": conn.extra_dejson.get("role"), + "SNOWFLAKE_DATABASE": conn.extra_dejson.get("database"), + "SNOWFLAKE_WAREHOUSE": conn.extra_dejson.get("warehouse"), + "SNOWFLAKE_SCHEMA": conn.schema, + } + + else: + print(f"Connection type {conn.type} is not yet supported.", file=sys.stderr) + sys.exit(1) + + return profile, profile_vars + + +def map_profile(conn_id, schema): + conn = BaseHook().get_connection(conn_id) + profile, profile_vars = create_profile_vars(conn, schema_override=schema) + return profile, profile_vars