Skip to content

Commit

Permalink
Add schema parameter and data_asset_name parsing (#75)
Browse files Browse the repository at this point in the history
* Add schema parameter and data_asset_name parsing

A new schema parameter will overwrite the supplied connection's
schema, if it exists. If a data_asset_name is provided in the form
'schema.table', the schema will be parsed out into self.schema and
data_asset_name will be updated to only be the table. This will also
override the supplied schema in the connection.

* Fix new test name

* Update great_expectations_provider/operators/great_expectations.py

Co-authored-by: Kaxil Naik <[email protected]>

* Remove conn param as it was a test from a different branch.

Co-authored-by: Kaxil Naik <[email protected]>
  • Loading branch information
denimalpaca and kaxil authored Dec 8, 2022
1 parent 981f586 commit 21bdf99
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 3 deletions.
20 changes: 17 additions & 3 deletions great_expectations_provider/operators/great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class GreatExpectationsOperator(BaseOperator):
:type return_json_dict: bool
:param use_open_lineage: If True (default), creates an OpenLineage action if an OpenLineage environment is found
:type use_open_lineage: bool
:param schema: If provided, overwrites the default schema provided by the connection
:type schema: Optional[str]
"""

ui_color = "#AFEEEE"
Expand Down Expand Up @@ -144,6 +146,7 @@ def __init__(
fail_task_on_validation_failure: bool = True,
return_json_dict: bool = False,
use_open_lineage: bool = True,
schema: Optional[str] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -170,6 +173,7 @@ def __init__(
self.is_dataframe = True if self.dataframe_to_validate is not None else False
self.datasource: Optional[Datasource] = None
self.batch_request: Optional[BatchRequestBase] = None
self.schema = schema

if self.is_dataframe and self.query_to_validate:
raise ValueError(
Expand Down Expand Up @@ -213,11 +217,21 @@ def __init__(
if isinstance(self.checkpoint_config, CheckpointConfig):
self.checkpoint_config = deep_filter_properties_iterable(properties=self.checkpoint_config.to_dict())

# If a schema is passed as part of the data_asset_name, use that schema
if self.data_asset_name and "." in self.data_asset_name:
# Assume data_asset_name is in the form "SCHEMA.TABLE"
# Schema parameter always takes priority
asset_list = self.data_asset_name.split(".")
self.schema = self.schema or asset_list[0]
# Update data_asset_name to be only the table
self.data_asset_name = asset_list[1]

def make_connection_string(self) -> str:
"""Builds connection strings based off existing Airflow connections. Only supports necessary extras."""
uri_string = ""
if not self.conn:
raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
schema = self.schema or self.conn.schema
conn_type = self.conn.conn_type
if conn_type in ("redshift", "postgres", "mysql", "mssql"):
odbc_connector = ""
Expand All @@ -227,11 +241,11 @@ def make_connection_string(self) -> str:
odbc_connector = "mysql"
else:
odbc_connector = "mssql+pyodbc"
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{self.conn.schema}" # noqa
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{schema}" # noqa
elif conn_type == "snowflake":
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{self.conn.schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}" # noqa
elif conn_type == "gcpbigquery":
uri_string = f"{self.conn.host}{self.conn.schema}"
uri_string = f"{self.conn.host}{schema}"
elif conn_type == "sqlite":
uri_string = f"sqlite:///{self.conn.host}"
# TODO: Add Athena and Trino support if possible
Expand Down
63 changes: 63 additions & 0 deletions tests/operators/test_great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,69 @@ def test_great_expectations_operator__make_connection_string_sqlite():
assert operator.make_connection_string() == test_conn_str


def test_great_expectations_operator__make_connection_string_schema_parameter():
    """An explicit ``schema`` kwarg must win over both the connection's schema
    and any schema prefix embedded in ``data_asset_name``."""
    expected_uri = (
        "snowflake://user:[email protected]/database/test_schema_parameter?warehouse=warehouse&role=role"
    )
    snowflake_extras = {
        "extra__snowflake__role": "role",
        "extra__snowflake__warehouse": "warehouse",
        "extra__snowflake__database": "database",
        "extra__snowflake__region": "region-east-1",
        "extra__snowflake__account": "account",
    }
    operator = GreatExpectationsOperator(
        task_id="task_id",
        data_context_config=in_memory_data_context_config,
        data_asset_name="test_schema.test_table",
        conn_id="snowflake_default",
        expectation_suite_name="suite",
        schema="test_schema_parameter",
    )
    # Inject the connection directly instead of relying on an Airflow metastore.
    operator.conn = Connection(
        conn_id="snowflake_default",
        conn_type="snowflake",
        host="connection",
        login="user",
        password="password",
        schema="schema",
        port=5439,
        extra=snowflake_extras,
    )
    operator.conn_type = operator.conn.conn_type
    assert operator.make_connection_string() == expected_uri


def test_great_expectations_operator__make_connection_string_data_asset_name_schema_parse():
    """A ``data_asset_name`` of the form ``schema.table`` must be split: the
    schema portion feeds the connection string and the table portion becomes
    the new ``data_asset_name``."""
    expected_uri = (
        "snowflake://user:[email protected]/database/test_schema?warehouse=warehouse&role=role"
    )
    snowflake_extras = {
        "extra__snowflake__role": "role",
        "extra__snowflake__warehouse": "warehouse",
        "extra__snowflake__database": "database",
        "extra__snowflake__region": "region-east-1",
        "extra__snowflake__account": "account",
    }
    operator = GreatExpectationsOperator(
        task_id="task_id",
        data_context_config=in_memory_data_context_config,
        data_asset_name="test_schema.test_table",
        conn_id="snowflake_default",
        expectation_suite_name="suite",
    )
    # Inject the connection directly instead of relying on an Airflow metastore.
    # Note: no ``schema`` on the Connection, so the parsed prefix must be used.
    operator.conn = Connection(
        conn_id="snowflake_default",
        conn_type="snowflake",
        host="connection",
        login="user",
        password="password",
        port=5439,
        extra=snowflake_extras,
    )
    operator.conn_type = operator.conn.conn_type
    assert operator.make_connection_string() == expected_uri
    assert operator.data_asset_name == "test_table"


def test_great_expectations_operator__make_connection_string_raise_error():
operator = GreatExpectationsOperator(
task_id="task_id",
Expand Down

0 comments on commit 21bdf99

Please sign in to comment.