Use Snowflake provider to build connection string
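Instead of hand-assembling a Snowflake SQLAlchemy URI from connection extras, the operator now asks the Snowflake provider's SnowflakeHook for a SQLAlchemy engine and derives the connection configuration from it, which also enables key-pair authentication. As a rough sketch (placeholder values, not part of the diff below), the new make_connection_configuration method returns a dict shaped like:

    # Illustrative shape of the returned configuration when key-pair auth
    # is configured; the URL comes from the provider's SQLAlchemy engine.
    {
        "url": "snowflake://GE_USER@my_account/ANALYTICS?role=TRANSFORMER&warehouse=COMPUTE_WH",
        "connect_args": {"private_key": b"<DER-encoded PKCS8 key bytes>"},
    }

and {"connection_string": "..."} when the provider package is unavailable and the URI is built manually.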
ivanstillfront committed Mar 1, 2023
1 parent c4aae5b commit 82e2741
Showing 5 changed files with 197 additions and 60 deletions.
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
ge-provider-3.10.9
102 changes: 76 additions & 26 deletions great_expectations_provider/operators/great_expectations.py
@@ -19,6 +19,7 @@

import os
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import great_expectations as ge
@@ -234,7 +235,7 @@ def __init__(
# Update data_asset_name to be only the table
self.data_asset_name = asset_list[1]

def make_connection_string(self) -> str:
def make_connection_configuration(self) -> Dict[str, str]:
"""Builds connection strings based off existing Airflow connections. Only supports necessary extras."""
uri_string = ""
if not self.conn:
@@ -251,48 +252,97 @@ def make_connection_string(self) -> str:
odbc_connector = "mssql+pyodbc"
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{self.schema}" # noqa
elif conn_type == "snowflake":
snowflake_account = (
self.conn.extra_dejson.get("account") or self.conn.extra_dejson["extra__snowflake__account"]
)
snowflake_region = (
self.conn.extra_dejson.get("region") or self.conn.extra_dejson["extra__snowflake__region"]
)
snowflake_database = (
self.conn.extra_dejson.get("database") or self.conn.extra_dejson["extra__snowflake__database"]
)
snowflake_warehouse = (
self.conn.extra_dejson.get("warehouse") or self.conn.extra_dejson["extra__snowflake__warehouse"]
)
snowflake_role = self.conn.extra_dejson.get("role") or self.conn.extra_dejson["extra__snowflake__role"]
try:
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

hook = SnowflakeHook(snowflake_conn_id=self.conn_id)

# Support the operator overriding the schema
# which is necessary for temp tables.
hook.schema = self.schema or hook.schema

conn = hook.get_connection(self.conn_id)
engine = hook.get_sqlalchemy_engine()

url = engine.url.render_as_string(hide_password=False)

private_key_file = conn.extra_dejson.get(
"extra__snowflake__private_key_file"
) or conn.extra_dejson.get("private_key_file")

if private_key_file:
private_key_pem = Path(private_key_file).read_bytes()

passphrase = None
if conn.password:
passphrase = conn.password.strip().encode()

p_key = serialization.load_pem_private_key(
private_key_pem, password=passphrase, backend=default_backend()
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
return {
# Unfortunately GE uses deepcopy when instantiating the SqlAlchemyExecutionEngine,
# which relies on pickle, and the SQLAlchemy Engine is not picklable.
# "engine": engine,
"url": url,
"connect_args": {
"private_key": pkb,
},
}

uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{snowflake_account}.{snowflake_region}/{snowflake_database}/{self.schema}?warehouse={snowflake_warehouse}&role={snowflake_role}" # noqa
return {"url": url}

except ImportError:
self.log.warning(
(
"Snowflake provider package not available, "
"attempting to build the connection string manually. "
"Key-based auth is not supported."
)
)

snowflake_account = (
self.conn.extra_dejson.get("account") or self.conn.extra_dejson["extra__snowflake__account"]
)
snowflake_region = (
self.conn.extra_dejson.get("region") or self.conn.extra_dejson["extra__snowflake__region"]
)
snowflake_database = (
self.conn.extra_dejson.get("database") or self.conn.extra_dejson["extra__snowflake__database"]
)
snowflake_warehouse = (
self.conn.extra_dejson.get("warehouse") or self.conn.extra_dejson["extra__snowflake__warehouse"]
)
snowflake_role = self.conn.extra_dejson.get("role") or self.conn.extra_dejson["extra__snowflake__role"]

uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{snowflake_account}.{snowflake_region}/{snowflake_database}/{self.schema}?warehouse={snowflake_warehouse}&role={snowflake_role}" # noqa

# private_key_file is optional, see:
# https://docs.snowflake.com/en/user-guide/sqlalchemy.html#key-pair-authentication-support
snowflake_private_key_file = self.conn.extra_dejson.get(
"private_key_file", self.conn.extra_dejson.get("extra__snowflake__private_key_file")
)
if snowflake_private_key_file:
uri_string += f"&private_key_file={snowflake_private_key_file}"
elif conn_type == "gcpbigquery":
uri_string = f"{self.conn.host}{self.schema}"
elif conn_type == "sqlite":
uri_string = f"sqlite:///{self.conn.host}"
# TODO: Add Athena and Trino support if possible
else:
raise ValueError(f"Conn type: {conn_type} is not supported.")
return uri_string
return {"connection_string": uri_string}

def build_configured_sql_datasource_config_from_conn_id(
self,
) -> Datasource:
conn_str = self.make_connection_string()
datasource_config = {
"name": f"{self.conn.conn_id}_configured_sql_datasource",
"execution_engine": {
"module_name": "great_expectations.execution_engine",
"class_name": "SqlAlchemyExecutionEngine",
"connection_string": conn_str,
**self.make_connection_configuration(),
},
"data_connectors": {
"default_configured_asset_sql_data_connector": {
@@ -327,7 +377,7 @@ def build_runtime_sql_datasource_config_from_conn_id(
"execution_engine": {
"module_name": "great_expectations.execution_engine",
"class_name": "SqlAlchemyExecutionEngine",
"connection_string": self.make_connection_string(),
**self.make_connection_configuration(),
},
"data_connectors": {
"default_runtime_data_connector": {
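For reference, the Snowflake branch above reads its settings from the Airflow connection's extras. A minimal sketch of a connection that would exercise the new code path, assuming the standard snowflake connection type and placeholder values throughout:

    import json

    from airflow.models.connection import Connection

    # Placeholder values only. "private_key_file" is optional; when present,
    # the operator switches to key-pair auth and uses conn.password as the
    # passphrase for the PEM-encoded private key.
    snowflake_conn = Connection(
        conn_id="snowflake_default",
        conn_type="snowflake",
        login="GE_USER",
        password="key-passphrase-if-any",
        extra=json.dumps(
            {
                "account": "my_account",
                "region": "eu-west-1",
                "database": "ANALYTICS",
                "warehouse": "COMPUTE_WH",
                "role": "TRANSFORMER",
                "private_key_file": "/opt/airflow/keys/rsa_key.p8",
            }
        ),
    )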
Empty file added host
Empty file.
2 changes: 2 additions & 0 deletions setup.cfg
@@ -41,6 +41,8 @@ install_requires =
tests =
parameterized
pytest
pytest-mock
apache-airflow-providers-snowflake>=3.3.0

[options.entry_points]
apache_airflow_provider=
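The two new test dependencies make it possible to unit-test the provider-backed path without a live Snowflake account. A hypothetical sketch (the operator's construction is elided, since its constructor arguments are not shown in this diff) of how pytest-mock's mocker fixture could stub SnowflakeHook so make_connection_configuration can be exercised offline:

    # Hypothetical test sketch; operator instantiation is elided.
    def test_url_comes_from_provider_engine(mocker):
        hook_cls = mocker.patch(
            "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook"
        )
        engine = hook_cls.return_value.get_sqlalchemy_engine.return_value
        engine.url.render_as_string.return_value = (
            "snowflake://GE_USER@my_account/ANALYTICS"
        )
        # No private_key_file in the extras, so the plain {"url": ...} branch runs.
        hook_cls.return_value.get_connection.return_value.extra_dejson = {}
        # config = operator.make_connection_configuration()
        # assert config == {"url": "snowflake://GE_USER@my_account/ANALYTICS"}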