diff --git a/great_expectations_provider/operators/great_expectations.py b/great_expectations_provider/operators/great_expectations.py
index 6f28cae..0eae2a1 100644
--- a/great_expectations_provider/operators/great_expectations.py
+++ b/great_expectations_provider/operators/great_expectations.py
@@ -111,6 +111,8 @@ class GreatExpectationsOperator(BaseOperator):
     :type return_json_dict: bool
     :param use_open_lineage: If True (default), creates an OpenLineage action if an OpenLineage environment is found
     :type use_open_lineage: bool
+    :param schema: If provided, overrides the default schema provided by the connection
+    :type schema: Optional[str]
     """
 
     ui_color = "#AFEEEE"
@@ -144,6 +146,7 @@ def __init__(
         fail_task_on_validation_failure: bool = True,
         return_json_dict: bool = False,
         use_open_lineage: bool = True,
+        schema: Optional[str] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -170,6 +173,7 @@ def __init__(
         self.is_dataframe = True if self.dataframe_to_validate is not None else False
         self.datasource: Optional[Datasource] = None
         self.batch_request: Optional[BatchRequestBase] = None
+        self.schema = schema
 
         if self.is_dataframe and self.query_to_validate:
             raise ValueError(
@@ -213,11 +217,21 @@ def __init__(
         if isinstance(self.checkpoint_config, CheckpointConfig):
             self.checkpoint_config = deep_filter_properties_iterable(properties=self.checkpoint_config.to_dict())
 
+        # If a schema is passed as part of the data_asset_name, use that schema
+        if self.data_asset_name and "." in self.data_asset_name:
+            # Assume data_asset_name is in the form "SCHEMA.TABLE"
+            # Schema parameter always takes priority
+            asset_list = self.data_asset_name.split(".")
+            self.schema = self.schema or asset_list[0]
+            # Update data_asset_name to be only the table
+            self.data_asset_name = asset_list[1]
+
     def make_connection_string(self) -> str:
         """Builds connection strings based off existing Airflow connections. Only supports necessary extras."""
         uri_string = ""
         if not self.conn:
             raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
+        schema = self.schema or self.conn.schema
         conn_type = self.conn.conn_type
         if conn_type in ("redshift", "postgres", "mysql", "mssql"):
             odbc_connector = ""
@@ -227,11 +241,11 @@ def make_connection_string(self) -> str:
                 odbc_connector = "mysql"
             else:
                 odbc_connector = "mssql+pyodbc"
-            uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{self.conn.schema}"  # noqa
+            uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{schema}"  # noqa
         elif conn_type == "snowflake":
-            uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{self.conn.schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}"  # noqa
+            uri_string = f"snowflake://{self.conn.login}:{self.conn.password}@{self.conn.extra_dejson['extra__snowflake__account']}.{self.conn.extra_dejson['extra__snowflake__region']}/{self.conn.extra_dejson['extra__snowflake__database']}/{schema}?warehouse={self.conn.extra_dejson['extra__snowflake__warehouse']}&role={self.conn.extra_dejson['extra__snowflake__role']}"  # noqa
         elif conn_type == "gcpbigquery":
-            uri_string = f"{self.conn.host}{self.conn.schema}"
+            uri_string = f"{self.conn.host}{schema}"
         elif conn_type == "sqlite":
             uri_string = f"sqlite:///{self.conn.host}"
         # TODO: Add Athena and Trino support if possible
diff --git a/tests/operators/test_great_expectations.py b/tests/operators/test_great_expectations.py
index 6086421..aa95685 100644
--- a/tests/operators/test_great_expectations.py
+++ b/tests/operators/test_great_expectations.py
@@ -862,6 +862,69 @@ def test_great_expectations_operator__make_connection_string_sqlite():
     assert operator.make_connection_string() == test_conn_str
 
 
+def test_great_expectations_operator__make_connection_string_schema_parameter():
+    test_conn_str = (
+        "snowflake://user:password@account.region-east-1/database/test_schema_parameter?warehouse=warehouse&role=role"
+    )
+    operator = GreatExpectationsOperator(
+        task_id="task_id",
+        data_context_config=in_memory_data_context_config,
+        data_asset_name="test_schema.test_table",
+        conn_id="snowflake_default",
+        expectation_suite_name="suite",
+        schema="test_schema_parameter",
+    )
+    operator.conn = Connection(
+        conn_id="snowflake_default",
+        conn_type="snowflake",
+        host="connection",
+        login="user",
+        password="password",
+        schema="schema",
+        port=5439,
+        extra={
+            "extra__snowflake__role": "role",
+            "extra__snowflake__warehouse": "warehouse",
+            "extra__snowflake__database": "database",
+            "extra__snowflake__region": "region-east-1",
+            "extra__snowflake__account": "account",
+        },
+    )
+    operator.conn_type = operator.conn.conn_type
+    assert operator.make_connection_string() == test_conn_str
+
+
+def test_great_expectations_operator__make_connection_string_data_asset_name_schema_parse():
+    test_conn_str = (
+        "snowflake://user:password@account.region-east-1/database/test_schema?warehouse=warehouse&role=role"
+    )
+    operator = GreatExpectationsOperator(
+        task_id="task_id",
+        data_context_config=in_memory_data_context_config,
+        data_asset_name="test_schema.test_table",
+        conn_id="snowflake_default",
+        expectation_suite_name="suite",
+    )
+    operator.conn = Connection(
+        conn_id="snowflake_default",
+        conn_type="snowflake",
+        host="connection",
+        login="user",
+        password="password",
+        port=5439,
+        extra={
+            "extra__snowflake__role": "role",
+            "extra__snowflake__warehouse": "warehouse",
+            "extra__snowflake__database": "database",
+            "extra__snowflake__region": "region-east-1",
+            "extra__snowflake__account": "account",
+        },
+    )
+    operator.conn_type = operator.conn.conn_type
+    assert operator.make_connection_string() == test_conn_str
+    assert operator.data_asset_name == "test_table"
+
+
 def test_great_expectations_operator__make_connection_string_raise_error():
     operator = GreatExpectationsOperator(
         task_id="task_id",
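
For context, the schema resolution introduced above follows a fixed precedence: an explicit schema argument wins, then a schema parsed from a dotted data_asset_name, then the schema stored on the Airflow connection. A minimal usage sketch under assumptions not in this PR (a snowflake_default connection, an existing expectation suite named "suite", and a standard project layout; task and asset names are illustrative):

    from great_expectations_provider.operators.great_expectations import GreatExpectationsOperator

    # Hypothetical DAG task: "analytics.orders" would be split into schema
    # "analytics" and table "orders", but the explicit schema="staging"
    # takes priority, so validation runs against staging.orders.
    validate_orders = GreatExpectationsOperator(
        task_id="validate_orders",
        conn_id="snowflake_default",
        data_asset_name="analytics.orders",
        schema="staging",
        expectation_suite_name="suite",
        data_context_root_dir="great_expectations",  # assumed project layout
    )

Omitting schema="staging" in the sketch would validate analytics.orders, and a plain data_asset_name="orders" with no schema argument would fall back to the connection's schema.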