Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 5 additions & 60 deletions providers/dbt/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import os
import shutil
import sys

import yaml
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.hooks.subprocess import SubprocessHook
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.operator_helpers import context_to_airflow_vars

from cosmos.providers.dbt.core.utils.profiles_generator import create_default_profiles, map_profile


class DBTBaseOperator(BaseOperator):

Expand Down Expand Up @@ -149,63 +148,9 @@ def run_command(self, cmd, env):
self.exception_handling(result)
return result

def create_default_profiles(self):
profile = {
"postgres_profile": {
"outputs": {
"dev": {
"type": "postgres",
"host": "{{ env_var('POSTGRES_HOST') }}",
"port": "{{ env_var('POSTGRES_PORT') | as_number }}",
"user": "{{ env_var('POSTGRES_USER') }}",
"pass": "{{ env_var('POSTGRES_PASSWORD') }}",
"dbname": "{{ env_var('POSTGRES_DATABASE') }}",
"schema": "{{ env_var('POSTGRES_SCHEMA') }}",
}
},
"target": "dev",
}
}
# Define the path to the directory and file
directory_path = "/home/astro/.dbt"
file_path = "/home/astro/.dbt/profiles.yml"

# Create the directory if it does not exist
if not os.path.exists(directory_path):
os.makedirs(directory_path)

# Create the file if it does not exist
if not os.path.exists(file_path):
print("profiles.yml not found - initializing.")
with open(file_path, "w") as file:
yaml.dump(profile, file)
print("done")
else:
print("profiles.yml found - skipping")

def map_profile(self):
conn = BaseHook().get_connection(self.conn_id)

if conn.conn_type == "postgres":
profile = "postgres_profile"
profile_vars = {
"POSTGRES_HOST": conn.host,
"POSTGRES_USER": conn.login,
"POSTGRES_PASSWORD": conn.password,
"POSTGRES_DATABASE": conn.schema,
"POSTGRES_PORT": str(conn.port),
"POSTGRES_SCHEMA": self.schema,
}

else:
print(f"Connection type {conn.type} is not yet supported.", file=sys.stderr)
sys.exit(1)

return profile, profile_vars

def build_and_run_cmd(self, env):
self.create_default_profiles()
profile, profile_vars = self.map_profile()
def build_and_run_cmd(self, env):
create_default_profiles()
profile, profile_vars = map_profile(self.conn_id, self.schema)
env = env | profile_vars
cmd = self.build_command() + ["--profile", profile]
result = self.run_command(cmd=cmd, env=env)
Expand Down
14 changes: 14 additions & 0 deletions providers/dbt/core/profiles/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
postgres_profile = {
"outputs": {
"dev": {
"type": "postgres",
"host": "{{ env_var('POSTGRES_HOST') }}",
"port": "{{ env_var('POSTGRES_PORT') | as_number }}",
"user": "{{ env_var('POSTGRES_USER') }}",
"pass": "{{ env_var('POSTGRES_PASSWORD') }}",
"dbname": "{{ env_var('POSTGRES_DATABASE') }}",
"schema": "{{ env_var('POSTGRES_SCHEMA') }}",
}
},
"target": "dev",
}
16 changes: 16 additions & 0 deletions providers/dbt/core/profiles/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
snowflake_profile = {
"target": "dev",
"outputs": {
"dev": {
"type": "snowflake",
"account": "{{ env_var('SNOWFLAKE_ACCOUNT') }}",
"user": "{{ env_var('SNOWFLAKE_USER') }}",
"password": "{{ env_var('SNOWFLAKE_PASSWORD') }}",
"role": "{{ env_var('SNOWFLAKE_ROLE') }}",
"database": "{{ env_var('SNOWFLAKE_DATABASE') }}",
"warehouse": "{{ env_var('SNOWFLAKE_WAREHOUSE') }}",
"schema": "{{ env_var('SNOWFLAKE_SCHEMA') }}",
"client_session_keep_alive": False,
}
},
}
66 changes: 66 additions & 0 deletions providers/dbt/core/utils/profiles_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import sys

import yaml
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection

from cosmos.providers.dbt.core.profiles.postgres import postgres_profile
from cosmos.providers.dbt.core.profiles.snowflake import snowflake_profile


def create_default_profiles():
profiles = {"postgres_profile": postgres_profile, "snowflake_profile": snowflake_profile}
# Define the path to the directory and file
directory_path = "/home/astro/.dbt"
file_path = "/home/astro/.dbt/profiles.yml"

# Create the directory if it does not exist
if not os.path.exists(directory_path):
os.makedirs(directory_path)

# Create the file if it does not exist
if not os.path.exists(file_path):
print("profiles.yml not found - initializing.")
with open(file_path, "w") as file:
yaml.dump(profiles, file)
print("done")
else:
print("profiles.yml found - skipping")


def create_profile_vars(conn: Connection, schema_override):
if conn.conn_type == "postgres":
profile = "postgres_profile"
profile_vars = {
"POSTGRES_HOST": conn.host,
"POSTGRES_USER": conn.login,
"POSTGRES_PASSWORD": conn.password,
"POSTGRES_DATABASE": conn.schema,
"POSTGRES_PORT": str(conn.port),
"POSTGRES_SCHEMA": schema_override,
}

elif conn.conn_type == "snowflake":
profile = "snowflake_profile"
profile_vars = {
"SNOWFLAKE_USER": conn.login,
"SNOWFLAKE_PASSWORD": conn.password,
"SNOWFLAKE_ACCOUNT": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}",
"SNOWFLAKE_ROLE": conn.extra_dejson.get("role"),
"SNOWFLAKE_DATABASE": conn.extra_dejson.get("database"),
"SNOWFLAKE_WAREHOUSE": conn.extra_dejson.get("warehouse"),
"SNOWFLAKE_SCHEMA": conn.schema,
}

else:
print(f"Connection type {conn.type} is not yet supported.", file=sys.stderr)
sys.exit(1)

return profile, profile_vars


def map_profile(conn_id, schema):
conn = BaseHook().get_connection(conn_id)
profile, profile_vars = create_profile_vars(conn, schema_override=schema)
return profile, profile_vars