Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,71 @@

logger = get_logger(__name__)

# Subset of dbt event types that represent errors/failures.
# Used (together with node status lifecycle events like NodeStart/NodeCompiling/
# NodeExecuting/NodeFinished) to build _DBT_EVENT_ALLOWLIST, which controls which
# events are surfaced in consumer tasks.
# Source: https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/events/types.py
_DBT_ERROR_EVENTS_TYPES = frozenset(
{
"InvalidOptionYAML",
"LogDbtProjectError",
"LogDbtProfileError",
"InputFileDiffError",
"PartialParsingErrorProcessingFile",
"PartialParsingError",
"ParsedFileLoadFailed",
"ParseInlineNodeError",
"RunningOperationCaughtError",
"SQLRunnerException",
"RunningOperationUncaughtError",
"CatchableExceptionOnRun",
"InternalErrorOnRun",
"GenericExceptionOnRun",
"NodeConnectionReleaseError",
"MainEncounteredError",
"RunResultFailure",
"RunResultError",
"CheckNodeTestFailure",
"LogSkipBecauseError",
"SendEventFailure",
"FlushEventsFailure",
"TrackingInitializeFailure",
"ArtifactUploadError",
Comment thread
pankajastro marked this conversation as resolved.
}
)

_DBT_NODE_STATUS_EVENT_TYPES = frozenset({"NodeStart", "NodeCompiling", "NodeExecuting", "NodeFinished"})

_DBT_EVENT_ALLOWLIST = _DBT_ERROR_EVENTS_TYPES | _DBT_NODE_STATUS_EVENT_TYPES


def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMsg) -> None:
logger.debug("dbt_log: %s", dbt_log)
sensitive_words = ["fail", "error"]
if isinstance(dbt_log, dict): # Subprocess
data = dbt_log.get("data", {})
info = dbt_log.get("info", {})

event_name = info.get("name")
if event_name not in _DBT_EVENT_ALLOWLIST:
return None
node_info = data.get("node_info")
status = node_info.get("node_status") if node_info else None
unique_id = node_info.get("unique_id") if node_info else None
start_time = node_info.get("node_started_at") if node_info else None
finish_time = node_info.get("node_finished_at") if node_info else None
msg = data.get("msg") or info.get("msg") or None
else: # Runner
Comment thread
tatiana marked this conversation as resolved.
event_name = getattr(getattr(dbt_log, "info", None), "name", None)
if event_name not in _DBT_EVENT_ALLOWLIST:
return None
node_info = getattr(dbt_log.data, "node_info", None)
unique_id = getattr(node_info, "unique_id") if node_info else None
status = getattr(node_info, "node_status", None) if node_info else None
Comment thread
pankajastro marked this conversation as resolved.
start_time = getattr(node_info, "node_started_at", None) if node_info else None
finish_time = getattr(node_info, "node_finished_at", None) if node_info else None
Comment thread
pankajastro marked this conversation as resolved.
msg = getattr(dbt_log.info, "msg", None)

# Special case when node status is the string "None"; only process messages that contain an error or fail word
if status in ["None"] and msg is not None:
# Check if there is error log message
for sensitive_word in sensitive_words:
if sensitive_word in msg.lower():
break
else:
return None

if unique_id:
dbt_event = {
"status": status,
Expand All @@ -84,11 +118,6 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMs
}

xcom_key = f"{unique_id.replace('.', '__')}_dbt_event"
Comment thread
pankajastro marked this conversation as resolved.
# Avoid redundant XCom writes (and global lock contention) by only pushing
# when the event payload has changed.
existing_event = get_xcom_val(task_instance=task_instance, key=xcom_key, task_ids=PRODUCER_WATCHER_TASK_ID)
if existing_event == dbt_event:
return None
Comment thread
tatiana marked this conversation as resolved.
safe_xcom_push(task_instance=task_instance, key=xcom_key, value=dbt_event)
Comment thread
pankajastro marked this conversation as resolved.


Expand Down
25 changes: 10 additions & 15 deletions tests/hooks/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,15 @@ def test_send_sigterm(mock_killpg, mock_getpgid):


@pytest.mark.parametrize(
"status,context,should_push,expect_assert",
"status,context,expect_assert",
[
("success", {"ti": MagicMock()}, True, False),
("failed", {"ti": MagicMock()}, True, False),
("running", {"ti": MagicMock()}, False, False),
(None, {"ti": MagicMock()}, False, False),
("success", None, False, True),
("failed", None, False, True),
("success", {"ti": MagicMock()}, False),
("failed", {"ti": MagicMock()}, False),
("success", None, True),
("failed", None, True),
],
)
def test_store_dbt_resource_status_from_log_param(status, context, should_push, expect_assert):
def test_store_dbt_resource_status_from_log_param(status, context, expect_assert):
# Prepare log line
log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}}
line = json.dumps(log_line)
Expand All @@ -106,13 +104,10 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push,
store_dbt_resource_status_from_log(
line, {"context": context}, tests_per_model={}, test_results_per_model={}
)
if should_push:
mock_push.assert_called_with(
task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status
)
assert mock_push.call_count == 2
else:
mock_push.assert_called_once()
mock_push.assert_called_with(
task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status
)
assert mock_push.call_count == 1


Comment thread
pankajastro marked this conversation as resolved.
def test_store_dbt_resource_status_from_log_invalid_json():
Expand Down
61 changes: 20 additions & 41 deletions tests/operators/_watcher/test_watcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,73 +57,52 @@ class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator):
assert sensor.extra_context == {}

@pytest.mark.parametrize(
"msg,should_push",
"event_name,should_push",
[
("1 of 30 START sql table model bq_dev.a ......................... [RUN]", False),
("6 of 30 FAIL creating sql view model bq_dev.stg_customers ............... [ERROR in 1.18s]", True),
("6 of 30 ERROR creating sql view model bq_dev.stg_customers ..................... [ERROR in 1.18s]", True),
(None, False),
("LogStartLine", False),
("NodeFinished", True),
("NodeStart", True),
],
)
def test_process_dbt_log_event_sensitive_words(self, msg, should_push):
def test_process_dbt_log_event_only_pushes_when_event_in_allowlist(self, event_name, should_push):
"""Only dbt events whose names are in _DBT_EVENT_ALLOWLIST are pushed to XCom."""
task_instance = Mock()

dbt_log = {
"data": {
"node_info": {
"unique_id": "model.test.my_model",
"node_status": "None",
"node_status": "success",
"node_started_at": "2024-01-01T00:00:00",
"node_finished_at": "2024-01-01T00:01:00",
},
Comment thread
pankajastro marked this conversation as resolved.
"msg": msg,
"msg": "model finished",
},
"info": {},
"info": {"name": event_name} if event_name is not None else {},
}

with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:

_process_dbt_log_event(task_instance, dbt_log)

if should_push:
mock_push.assert_called_once()
call_kwargs = mock_push.call_args.kwargs
assert call_kwargs["key"] == "model__test__my_model_dbt_event"
assert call_kwargs["value"]["status"] == "success"
assert call_kwargs["value"]["msg"] == "model finished"
else:
mock_push.assert_not_called()

def test_process_dbt_log_event_skips_duplicate_event(self):
def test_process_dbt_log_event_skips_when_no_unique_id(self):
"""Events with no node_info.unique_id are not pushed."""
task_instance = Mock()

dbt_log = {
"data": {
"node_info": {
"unique_id": "model.test.my_model",
"node_status": "success",
"node_started_at": "2024-01-01T00:00:00",
"node_finished_at": "2024-01-01T00:01:00",
},
"msg": "model finished",
},
"info": {},
"data": {"node_info": {}, "msg": "some log"},
"info": {"name": "NodeFinished"},
}

duplicate_event = {
"status": "success",
"start_time": "2024-01-01T00:00:00",
"finish_time": "2024-01-01T00:01:00",
"msg": "model finished",
}

with (
patch(
"cosmos.operators._watcher.base.get_xcom_val",
return_value=duplicate_event,
),
patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push,
patch(
"cosmos.operators._watcher.base._iso_to_string",
side_effect=lambda x: x,
),
):
result = _process_dbt_log_event(task_instance, dbt_log)

assert result is None
with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:
_process_dbt_log_event(task_instance, dbt_log)
mock_push.assert_not_called()
Loading