Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 111 additions & 10 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from abc import ABC
from os import PathLike
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence

import kubernetes.client as k8s
from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import (
Expand Down Expand Up @@ -189,6 +189,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:


class DbtTestWarningHandler(KubernetesPodOperatorCallback): # type: ignore[misc]
"""
This handler can detect warnings from:
1. Regular dbt tests (using the standard "Done. PASS=X WARN=Y" pattern)
2. Source freshness tests (using "WARN freshness of..." pattern)
"""

def __init__(
self,
on_warning_callback: Callable[..., Any],
Expand Down Expand Up @@ -224,21 +230,116 @@ def on_pod_completion( # type: ignore[override]
self.operator.log.warning(f"Cannot handle dbt warnings for task of type {type(task)}.")
return

logs = [log.decode("utf-8") for log in task.pod_manager.read_pod_logs(pod, "base") if log.decode("utf-8") != ""]

warn_count_pattern = re.compile(r"Done\. (?:\w+=\d+ )*WARN=(\d+)(?: \w+=\d+)*")
warn_count = warn_count_pattern.search("\n".join(logs))
if not warn_count:
# Get the logs from the pod
logs = []
for log in task.pod_manager.read_pod_logs(pod, "base"):
decoded_log = log.decode("utf-8")
if decoded_log != "":
logs.append(decoded_log)

logs_text = "\n".join(logs)

# Check for warnings
warning_detected = False
test_names, test_results = [], []
if isinstance(task, DbtTestKubernetesOperator):
warn_count = self._detect_standard_warnings(logs_text)
if warn_count:
self.operator.log.info(f"Detected {warn_count} warnings using standard pattern")
# test_names, test_results = self._extract_standard_log_issues(logs_text)
warning_detected = True
elif isinstance(task, DbtSourceKubernetesOperator):
source_freshness_warnings = self._detect_source_freshness_warnings(logs_text)
if source_freshness_warnings:
self.operator.log.info(f"Detected {len(source_freshness_warnings)} source freshness warnings")
# test_names = [w["name"] for w in source_freshness_warnings]
# test_results = [w["status"] for w in source_freshness_warnings]
warning_detected = True

if not warning_detected:
self.operator.log.warning(
"Failed to scrape warning count from the pod logs."
"Potential warning callbacks could not be triggered."
)
return

if int(warn_count.group(1)) > 0:
test_names, test_results = extract_log_issues(logs)
context_merge(self.context, test_names=test_names, test_results=test_results)
self.on_warning_callback(self.context)
test_names, test_results = extract_log_issues(logs)
context_merge(self.context, test_names=test_names, test_results=test_results)
self.on_warning_callback(self.context)

def _detect_standard_warnings(self, log_text: str) -> Optional[int]:
"""
Detect warnings using the standard dbt summary pattern.

Pattern: "Done. PASS=X WARN=Y ERROR=Z SKIP=W"

:param log_text: Complete log text from the pod
:return: Number of warnings detected, or None if pattern not found
"""
warn_count_pattern = re.compile(r"Done\. (?:\w+=\d+ )*WARN=(\d+)(?: \w+=\d+)*")
match = warn_count_pattern.search(log_text)

if match:
return int(match.group(1))
return None

def _detect_source_freshness_warnings(self, log_text: str) -> List[Dict[str, Any]]:
"""
Detect source freshness warnings from dbt logs.

Pattern examples:
- "15:49:21 1 of 1 WARN freshness of auction_net.auction_net_raw ... [WARN in 0.90s]"
- "WARN freshness of source_name.table_name"

:param log_text: Complete log text from the pod
:return: List of warning dictionaries
"""
warnings = []

# Primary pattern for source freshness warnings
# Matches: "HH:MM:SS X of Y WARN freshness of source.table ... [WARN in Xs]"
freshness_pattern = re.compile(
r"(\d{2}:\d{2}:\d{2})\s+" # timestamp
r"\d+\s+of\s+\d+\s+" # "X of Y"
r"WARN\s+freshness\s+of\s+" # "WARN freshness of"
r"([^\s]+)" # source name
r".*?\[WARN\s+in\s+([\d.]+)s\]" # execution time
)

for match in freshness_pattern.finditer(log_text):
timestamp = match.group(1)
source_name = match.group(2)
execution_time = match.group(3)

warnings.append(
{
"name": f"source_freshness_{source_name}",
"status": "WARN",
"type": "source_freshness",
"source": source_name,
"timestamp": timestamp,
"execution_time": execution_time,
}
)

# Secondary pattern for simpler source freshness warnings
# Matches: "WARN freshness of source_name"
simple_freshness_pattern = re.compile(r"WARN\s+freshness\s+of\s+([^\s]+)")

for match in simple_freshness_pattern.finditer(log_text):
source_name = match.group(1)
# Only add if not already captured by primary pattern
if not any(w["source"] == source_name for w in warnings):
warnings.append(
{
"name": f"source_freshness_{source_name}",
"status": "WARN",
"type": "source_freshness",
"source": source_name,
}
)

return warnings


class DbtWarningKubernetesOperator(DbtKubernetesBaseOperator, ABC):
Expand Down
207 changes: 206 additions & 1 deletion tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,211 @@ def read_pod_logs(self, pod, container):
return (log.encode("utf-8") for log in self.log_string.split("\n"))


def create_test_handler():
    """Build a DbtTestWarningHandler wired to mocks and return it with those mocks.

    Returns a 4-tuple: (handler, callback mock, operator mock, context dict).
    """
    callback_mock = Mock()
    operator_mock = Mock()
    context_stub = {"task_instance": Mock()}
    return (
        DbtTestWarningHandler(
            on_warning_callback=callback_mock,
            operator=operator_mock,
            context=context_stub,
        ),
        callback_mock,
        operator_mock,
        context_stub,
    )


@pytest.mark.parametrize(
    ("log_text", "expected_warn_count"),
    [
        # Standard warning with summary
        (
            """
19:48:25 Concurrency: 4 threads (target='target')
19:48:27 1 of 2 WARN dbt_utils_accepted_range ..................... [WARN 117 in 1.83s]
19:48:27 2 of 2 PASS unique_table__uuid ................................................ [PASS in 1.85s]
19:48:27 Done. PASS=1 WARN=1 ERROR=0 SKIP=0 TOTAL=2
""",
            1,
        ),
        # Multiple warnings
        (
            """
19:48:25 Concurrency: 4 threads (target='target')
19:48:27 1 of 3 WARN test_one ..................... [WARN in 1.83s]
19:48:27 2 of 3 WARN test_two ..................... [WARN in 1.85s]
19:48:27 3 of 3 PASS test_three ................... [PASS in 1.85s]
19:48:27 Done. PASS=1 WARN=2 ERROR=0 SKIP=0 TOTAL=3
""",
            2,
        ),
        # No warnings
        (
            """
19:48:25 Concurrency: 4 threads (target='target')
19:48:27 1 of 2 PASS test_one ..................... [PASS in 1.83s]
19:48:27 2 of 2 PASS test_two ..................... [PASS in 1.85s]
19:48:27 Done. PASS=2 WARN=0 ERROR=0 SKIP=0 TOTAL=2
""",
            0,
        ),
        # No summary (like source freshness)
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:20 1 of 1 START freshness of auction_net.raw .......................... [RUN]
15:49:21 1 of 1 WARN freshness of auction_net.raw ........................... [WARN in 0.90s]
15:49:21 Done.
""",
            None,
        ),
    ],
)
def test_detect_standard_warnings(log_text, expected_warn_count):
    """The summary-line parser yields the WARN count, or None when no summary carries one."""
    handler, *_ = create_test_handler()
    assert handler._detect_standard_warnings(log_text) == expected_warn_count


@pytest.mark.parametrize(
    ("log_text", "expected_warning_count", "expected_sources"),
    [
        # Single source freshness warning
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:20 1 of 1 START freshness of auction_net.auction_net_raw .......................... [RUN]
15:49:21 1 of 1 WARN freshness of auction_net.auction_net_raw ........................... [WARN in 0.90s]
15:49:21 Done.
""",
            1,
            ["auction_net.auction_net_raw"],
        ),
        # Multiple source freshness warnings
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:20 1 of 3 START freshness of source1.table1 .......................... [RUN]
15:49:21 1 of 3 WARN freshness of source1.table1 ........................... [WARN in 0.90s]
15:49:21 2 of 3 START freshness of source2.table2 .......................... [RUN]
15:49:22 2 of 3 WARN freshness of source2.table2 ........................... [WARN in 1.20s]
15:49:22 3 of 3 START freshness of source3.table3 .......................... [RUN]
15:49:23 3 of 3 PASS freshness of source3.table3 ........................... [PASS in 0.45s]
15:49:23 Done.
""",
            2,
            ["source1.table1", "source2.table2"],
        ),
        # No source freshness warnings - all pass
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:20 1 of 1 START freshness of auction_net.raw .......................... [RUN]
15:49:21 1 of 1 PASS freshness of auction_net.raw ........................... [PASS in 0.90s]
15:49:21 Done.
""",
            0,
            [],
        ),
        # Empty source freshness log
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:21 Done.
""",
            0,
            [],
        ),
    ],
)
def test_detect_source_freshness_warnings(log_text, expected_warning_count, expected_sources):
    """The freshness parser reports the expected number of warnings and the expected sources."""
    handler, *_ = create_test_handler()
    detected = handler._detect_source_freshness_warnings(log_text)
    assert len(detected) == expected_warning_count

    # Nothing further to verify when no particular source is expected
    if not expected_sources:
        return

    reported_sources = [entry["source"] for entry in detected]
    for source_name in expected_sources:
        assert source_name in reported_sources


@pytest.mark.parametrize(
    ("log_text",),
    [
        # Source freshness log with single warning
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:18 Concurrency: 2 threads (target='default')
15:49:20 1 of 1 START freshness of auction_net.auction_net_raw .......................... [RUN]
15:49:21 1 of 1 WARN freshness of auction_net.auction_net_raw ........................... [WARN in 0.90s]
15:49:21 Finished running 1 source in 0 hours 0 minutes and 3.27 seconds (3.27s).
15:49:21 Done.
""",
        ),
        # Source freshness log with multiple warnings
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:18 Concurrency: 2 threads (target='default')
15:49:20 1 of 3 START freshness of source1.table1 .......................... [RUN]
15:49:21 1 of 3 WARN freshness of source1.table1 ........................... [WARN in 0.90s]
15:49:21 2 of 3 START freshness of source2.table2 .......................... [RUN]
15:49:22 2 of 3 WARN freshness of source2.table2 ........................... [WARN in 1.20s]
15:49:22 3 of 3 START freshness of source3.table3 .......................... [RUN]
15:49:23 3 of 3 PASS freshness of source3.table3 ........................... [PASS in 0.45s]
15:49:23 Finished running 3 sources in 0 hours 0 minutes and 5.12 seconds (5.12s).
15:49:23 Done.
""",
        ),
        # Source freshness log with no warnings
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:18 Concurrency: 2 threads (target='default')
15:49:20 1 of 2 START freshness of auction_net.raw .......................... [RUN]
15:49:21 1 of 2 PASS freshness of auction_net.raw ........................... [PASS in 0.90s]
15:49:21 2 of 2 START freshness of another_source.table .......................... [RUN]
15:49:22 2 of 2 PASS freshness of another_source.table ........................... [PASS in 0.45s]
15:49:22 Finished running 2 sources in 0 hours 0 minutes and 4.32 seconds (4.32s).
15:49:22 Done.
""",
        ),
        # Source freshness log with mixed results
        (
            """
15:49:18 Found 205 models, 27 data tests, 66 sources, 639 macros
15:49:18 Concurrency: 2 threads (target='default')
15:49:20 1 of 4 START freshness of source_a.table_a .......................... [RUN]
15:49:21 1 of 4 PASS freshness of source_a.table_a ........................... [PASS in 0.90s]
15:49:21 2 of 4 START freshness of source_b.table_b .......................... [RUN]
15:49:22 2 of 4 WARN freshness of source_b.table_b ........................... [WARN in 1.10s]
15:49:22 3 of 4 START freshness of source_c.table_c .......................... [RUN]
15:49:23 3 of 4 PASS freshness of source_c.table_c ........................... [PASS in 0.45s]
15:49:23 4 of 4 START freshness of source_d.table_d .......................... [RUN]
15:49:24 4 of 4 WARN freshness of source_d.table_d ........................... [WARN in 0.78s]
15:49:24 Finished running 4 sources in 0 hours 0 minutes and 6.45 seconds (6.45s).
15:49:24 Done.
""",
        ),
    ],
)
def test_source_freshness_log_formats(log_text):
    """Each "WARN freshness of" occurrence in a log yields one well-formed warning entry."""
    handler, *_ = create_test_handler()
    detected = handler._detect_source_freshness_warnings(log_text)

    # The number of parsed warnings must equal the number of WARN freshness lines
    assert len(detected) == log_text.count("WARN freshness of")

    # Every entry carries the required keys with the fixed status/type values
    for entry in detected:
        assert "name" in entry
        assert "status" in entry
        assert entry["status"] == "WARN"
        assert "type" in entry
        assert entry["type"] == "source_freshness"
        assert "source" in entry


@pytest.mark.parametrize(
("log_string", "should_call"),
(
Expand All @@ -210,7 +415,7 @@ def read_pod_logs(self, pod, container):
19:48:25
19:48:25 1 of 2 START test dbt_utils_accepted_range_table_col__12__0 ................... [RUN]
19:48:25 2 of 2 START test unique_table__uuid .......................................... [RUN]
19:48:27 1 of 2 WARN 252 dbt_utils_accepted_range_table_col__12__0 ..................... [WARN 117 in 1.83s]
19:48:27 1 of 2 WARN dbt_utils_accepted_range_table_col__12__0 ..................... [WARN in 1.83s]
19:48:27 2 of 2 PASS unique_table__uuid ................................................ [PASS in 1.85s]
19:48:27
19:48:27 Finished running 2 tests, 1 hook in 0 hours 0 minutes and 12.86 seconds (12.86s).
Expand Down