Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add schema parameter and data_asset_name parsing #75

Merged
merged 4 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions great_expectations_provider/operators/great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class GreatExpectationsOperator(BaseOperator):
:type return_json_dict: bool
:param use_open_lineage: If True (default), creates an OpenLineage action if an OpenLineage environment is found
:type use_open_lineage: bool
:param schema: If provided, overwrites the default schema provded by the connection
:type schema: Optional[str]
"""

ui_color = "#AFEEEE"
Expand Down Expand Up @@ -144,6 +146,7 @@ def __init__(
fail_task_on_validation_failure: bool = True,
return_json_dict: bool = False,
use_open_lineage: bool = True,
schema: Optional[str] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -170,6 +173,7 @@ def __init__(
self.is_dataframe = True if self.dataframe_to_validate is not None else False
self.datasource: Optional[Datasource] = None
self.batch_request: Optional[BatchRequestBase] = None
self.schema = schema

if self.is_dataframe and self.query_to_validate:
raise ValueError(
Expand Down Expand Up @@ -213,11 +217,21 @@ def __init__(
if isinstance(self.checkpoint_config, CheckpointConfig):
self.checkpoint_config = deep_filter_properties_iterable(properties=self.checkpoint_config.to_dict())

# If a schema is passed as part of the data_asset_name, use that schema
if self.data_asset_name and "." in self.data_asset_name:
# Assume data_asset_name is in the form "SCHEMA.TABLE"
# Schema parameter always takes priority
asset_list = self.data_asset_name.split(".")
self.schema = self.schema or asset_list[0]
# Update data_asset_name to be only the table
self.data_asset_name = asset_list[1]

def make_connection_string(self) -> str:
"""Builds connection strings based off existing Airflow connections. Only supports necessary extras."""
uri_string = ""
if not self.conn:
raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
schema = self.schema or self.conn.schema
conn_type = self.conn.conn_type
if conn_type in ("redshift", "postgres", "mysql", "mssql"):
odbc_connector = ""
Expand All @@ -227,11 +241,11 @@ def make_connection_string(self) -> str:
odbc_connector = "mysql"
else:
odbc_connector = "mssql+pyodbc"
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{self.conn.schema}" # noqa
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{schema}" # noqa
elif conn_type == "snowflake":
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{self.conn.schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
elif conn_type == "gcpbigquery":
uri_string = f"{self.conn.host}{self.conn.schema}"
uri_string = f"{self.conn.host}{schema}"
elif conn_type == "sqlite":
uri_string = f"sqlite:///{self.conn.host}"
# TODO: Add Athena and Trino support if possible
Expand Down
63 changes: 63 additions & 0 deletions tests/operators/test_great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,69 @@ def test_great_expectations_operator__make_connection_string_sqlite():
assert operator.make_connection_string() == test_conn_str


def test_great_expectations_operator__make_connection_string_schema_parameter():
test_conn_str = (
"snowflake://user:[email protected]/database/test_schema_parameter?warehouse=warehouse&role=role"
)
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_schema.test_table",
conn_id="snowflake_default",
expectation_suite_name="suite",
schema="test_schema_parameter",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
schema="schema",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
},
)
operator.conn_type = operator.conn.conn_type
assert operator.make_connection_string() == test_conn_str


def test_great_expectations_operator__make_connection_string_data_asset_name_schema_parse():
test_conn_str = (
"snowflake://user:[email protected]/database/test_schema?warehouse=warehouse&role=role"
)
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_schema.test_table",
conn_id="snowflake_default",
expectation_suite_name="suite",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
},
)
operator.conn_type = operator.conn.conn_type
assert operator.make_connection_string() == test_conn_str
assert operator.data_asset_name == "test_table"


def test_great_expectations_operator__make_connection_string_raise_error():
operator = GreatExpectationsOperator(
task_id="task_id",
Expand Down