Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions cosmos/operators/_asynchronous/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ def __init__(
self.project_dir = project_dir
self.profile_config = profile_config
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile # type: ignore
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]
self.extra_context = extra_context or {}
self.configuration: dict[str, Any] = {}
self.dbt_kwargs = dbt_kwargs or {}
Expand Down Expand Up @@ -103,6 +100,8 @@ def __init__(
self.async_context["profile_type"] = self.profile_config.get_profile_type()
self.async_context["async_operator"] = BigQueryInsertJobOperator
self.compiled_sql = ""
self.gcp_project = ""
self.dataset = ""

@property
def base_cmd(self) -> list[str]:
Expand Down Expand Up @@ -145,10 +144,10 @@ def execute(self, context: Context, **kwargs: Any) -> None:
super().execute(context=context)
else:
self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context)
self._store_compiled_sql(context=context)
self._store_template_fields(context=context)

@provide_session
def _store_compiled_sql(self, context: Context, session: Session = NEW_SESSION) -> None:
def _store_template_fields(self, context: Context, session: Session = NEW_SESSION) -> None:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.taskinstance import TaskInstance

Expand All @@ -159,6 +158,10 @@ def _store_compiled_sql(self, context: Context, session: Session = NEW_SESSION)
self.log.debug("Executed SQL is: %s", sql)
self.compiled_sql = sql

profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# need to refresh the rendered task field record in the db because Airflow only does this
# before executing the task, not after
ti = context["ti"]
Expand Down Expand Up @@ -188,5 +191,5 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
job_id = super().execute_complete(context=context, event=event)
self.log.info("Configuration is %s", str(self.configuration))
self._store_compiled_sql(context=context)
self._store_template_fields(context=context)
return job_id
10 changes: 6 additions & 4 deletions tests/operators/_asynchronous/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def test_dbt_run_airflow_async_bigquery_operator_init(profile_config_mock):
assert operator.project_dir == "/path/to/project"
assert operator.profile_config == profile_config_mock
assert operator.gcp_conn_id == "google_cloud_default"
assert operator.gcp_project == "test_project"
assert operator.dataset == "test_dataset"


def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock):
Expand Down Expand Up @@ -134,15 +132,19 @@ def test_store_compiled_sql(mock_rendered_ti, mock_get_remote_sql, profile_confi
mock_task_instance.task = operator
mock_context = {"ti": mock_task_instance}

operator._store_compiled_sql(mock_context, session=mock_session)
operator._store_template_fields(mock_context, session=mock_session)
# check if gcp_project and dataset are set after the tasks gets executed

Comment thread
joppevos marked this conversation as resolved.
assert operator.compiled_sql == "SELECT * FROM test_table;"
assert operator.dataset == "test_dataset"
assert operator.gcp_project == "test_project"

mock_rendered_ti.assert_called_once()
mock_session.add.assert_called_once()
mock_session.query().filter().delete.assert_called_once()


@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_compiled_sql")
@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_template_fields")
def test_execute_complete(mock_store_sql, profile_config_mock):
mock_context = Mock()
mock_event = {"job_id": "test_job"}
Expand Down