diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index e5d176b56c..36b7e59459 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -44,14 +44,54 @@ logger = get_logger(__name__) +# Subset of dbt event types that represent errors/failures. +# Used (together with node status lifecycle events like NodeStart/NodeCompiling/ +# NodeExecuting/NodeFinished) to build _DBT_EVENT_ALLOWLIST, which controls which +# events are surfaced in consumer tasks. +# Source: https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/events/types.py +_DBT_ERROR_EVENTS_TYPES = frozenset( + { + "InvalidOptionYAML", + "LogDbtProjectError", + "LogDbtProfileError", + "InputFileDiffError", + "PartialParsingErrorProcessingFile", + "PartialParsingError", + "ParsedFileLoadFailed", + "ParseInlineNodeError", + "RunningOperationCaughtError", + "SQLRunnerException", + "RunningOperationUncaughtError", + "CatchableExceptionOnRun", + "InternalErrorOnRun", + "GenericExceptionOnRun", + "NodeConnectionReleaseError", + "MainEncounteredError", + "RunResultFailure", + "RunResultError", + "CheckNodeTestFailure", + "LogSkipBecauseError", + "SendEventFailure", + "FlushEventsFailure", + "TrackingInitializeFailure", + "ArtifactUploadError", + } +) + +_DBT_NODE_STATUS_EVENT_TYPES = frozenset({"NodeStart", "NodeCompiling", "NodeExecuting", "NodeFinished"}) + +_DBT_EVENT_ALLOWLIST = _DBT_ERROR_EVENTS_TYPES | _DBT_NODE_STATUS_EVENT_TYPES + def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMsg) -> None: logger.debug("dbt_log: %s", dbt_log) - sensitive_words = ["fail", "error"] if isinstance(dbt_log, dict): # Subprocess data = dbt_log.get("data", {}) info = dbt_log.get("info", {}) + event_name = info.get("name") + if event_name not in _DBT_EVENT_ALLOWLIST: + return None node_info = data.get("node_info") status = node_info.get("node_status") if node_info else None unique_id = node_info.get("unique_id") if node_info else None @@ 
-59,6 +99,9 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMs finish_time = node_info.get("node_finished_at") if node_info else None msg = data.get("msg") or info.get("msg") or None else: # Runner + event_name = getattr(getattr(dbt_log, "info", None), "name", None) + if event_name not in _DBT_EVENT_ALLOWLIST: + return None node_info = getattr(dbt_log.data, "node_info", None) unique_id = getattr(node_info, "unique_id") if node_info else None status = getattr(node_info, "node_status", None) if node_info else None @@ -66,15 +109,6 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMs finish_time = getattr(node_info, "node_finished_at", None) if node_info else None msg = getattr(dbt_log.info, "msg", None) - # Special case when node status is the string "None"; only process messages that contain an error or fail word - if status in ["None"] and msg is not None: - # Check if there is error log message - for sensitive_word in sensitive_words: - if sensitive_word in msg.lower(): - break - else: - return None - if unique_id: dbt_event = { "status": status, @@ -84,11 +118,6 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMs } xcom_key = f"{unique_id.replace('.', '__')}_dbt_event" - # Avoid redundant XCom writes (and global lock contention) by only pushing - # when the event payload has changed. 
- existing_event = get_xcom_val(task_instance=task_instance, key=xcom_key, task_ids=PRODUCER_WATCHER_TASK_ID) - if existing_event == dbt_event: - return None safe_xcom_push(task_instance=task_instance, key=xcom_key, value=dbt_event) diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 5f68f3ae04..00e1bcd351 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -81,17 +81,15 @@ def test_send_sigterm(mock_killpg, mock_getpgid): @pytest.mark.parametrize( - "status,context,should_push,expect_assert", + "status,context,expect_assert", [ - ("success", {"ti": MagicMock()}, True, False), - ("failed", {"ti": MagicMock()}, True, False), - ("running", {"ti": MagicMock()}, False, False), - (None, {"ti": MagicMock()}, False, False), - ("success", None, False, True), - ("failed", None, False, True), + ("success", {"ti": MagicMock()}, False), + ("failed", {"ti": MagicMock()}, False), + ("success", None, True), + ("failed", None, True), ], ) -def test_store_dbt_resource_status_from_log_param(status, context, should_push, expect_assert): +def test_store_dbt_resource_status_from_log_param(status, context, expect_assert): # Prepare log line log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}} line = json.dumps(log_line) @@ -106,13 +104,10 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, store_dbt_resource_status_from_log( line, {"context": context}, tests_per_model={}, test_results_per_model={} ) - if should_push: - mock_push.assert_called_with( - task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status - ) - assert mock_push.call_count == 2 - else: - mock_push.assert_called_once() + mock_push.assert_called_with( + task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status + ) + assert mock_push.call_count == 1 def test_store_dbt_resource_status_from_log_invalid_json(): diff --git 
a/tests/operators/_watcher/test_watcher_base.py b/tests/operators/_watcher/test_watcher_base.py index a56350ba2c..2892ca7b8d 100644 --- a/tests/operators/_watcher/test_watcher_base.py +++ b/tests/operators/_watcher/test_watcher_base.py @@ -57,73 +57,53 @@ class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): assert sensor.extra_context == {} @pytest.mark.parametrize( - "msg,should_push", + "event_name,should_push", [ - ("1 of 30 START sql table model bq_dev.a ......................... [RUN]", False), - ("6 of 30 FAIL creating sql view model bq_dev.stg_customers ............... [ERROR in 1.18s]", True), - ("6 of 30 ERROR creating sql view model bq_dev.stg_customers ..................... [ERROR in 1.18s]", True), + (None, False), + ("LogStartLine", False), + ("NodeFinished", True), + ("NodeStart", True), ], ) - def test_process_dbt_log_event_sensitive_words(self, msg, should_push): + def test_process_dbt_log_event_only_pushes_when_event_in_allowlist(self, event_name, should_push): + """Only dbt events whose names are in _DBT_EVENT_ALLOWLIST are pushed to XCom.""" task_instance = Mock() dbt_log = { "data": { "node_info": { "unique_id": "model.test.my_model", - "node_status": "None", + "node_status": "success", "node_started_at": "2024-01-01T00:00:00", "node_finished_at": "2024-01-01T00:01:00", }, - "msg": msg, + "msg": "model finished", }, - "info": {}, + "info": {"name": event_name} if event_name is not None else {}, } with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: _process_dbt_log_event(task_instance, dbt_log) if should_push: mock_push.assert_called_once() + call_kwargs = mock_push.call_args.kwargs + assert call_kwargs["key"] == "model__test__my_model_dbt_event" + assert call_kwargs["value"]["status"] == "success" + assert call_kwargs["value"]["msg"] == "model finished" else: mock_push.assert_not_called() - def test_process_dbt_log_event_skips_duplicate_event(self): + def
test_process_dbt_log_event_skips_when_no_unique_id(self): + """Events with no node_info.unique_id are not pushed.""" task_instance = Mock() dbt_log = { - "data": { - "node_info": { - "unique_id": "model.test.my_model", - "node_status": "success", - "node_started_at": "2024-01-01T00:00:00", - "node_finished_at": "2024-01-01T00:01:00", - }, - "msg": "model finished", - }, - "info": {}, + "data": {"node_info": {}, "msg": "some log"}, + "info": {"name": "NodeFinished"}, } - duplicate_event = { - "status": "success", - "start_time": "2024-01-01T00:00:00", - "finish_time": "2024-01-01T00:01:00", - "msg": "model finished", - } - - with ( - patch( - "cosmos.operators._watcher.base.get_xcom_val", - return_value=duplicate_event, - ), - patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push, - patch( - "cosmos.operators._watcher.base._iso_to_string", - side_effect=lambda x: x, - ), - ): - result = _process_dbt_log_event(task_instance, dbt_log) - - assert result is None + with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: + _process_dbt_log_event(task_instance, dbt_log) mock_push.assert_not_called()