Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b6a4ef4
add defferred logging callback
johnhoran Jan 9, 2026
22744c0
log callbacks
johnhoran Jan 9, 2026
c3110ad
log callbacks
johnhoran Jan 9, 2026
593a2f6
use sync logging
johnhoran Jan 9, 2026
413c2ec
use sync logging
johnhoran Jan 9, 2026
326ba80
log last log time info
johnhoran Jan 12, 2026
e955453
log status
johnhoran Jan 12, 2026
82cd175
check type
johnhoran Jan 12, 2026
488d75e
set post_termination_timeout
johnhoran Jan 12, 2026
26028d5
cleanup
johnhoran Jan 12, 2026
69d98fc
remove left over code
johnhoran Jan 12, 2026
4f7efc8
fix tests
johnhoran Jan 15, 2026
66a2f63
tests green
johnhoran Jan 16, 2026
36cbfc8
don't follow logs
johnhoran Jan 16, 2026
33107bd
check async logs
johnhoran Jan 19, 2026
2a1656b
test logging
johnhoran Jan 19, 2026
32afc4f
cleanup
johnhoran Jan 19, 2026
8d4a6c5
resolve failure
johnhoran Jan 19, 2026
1771384
format
johnhoran Jan 19, 2026
f78339c
fix doc
johnhoran Jan 19, 2026
690b396
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Jan 20, 2026
f3cf176
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Jan 26, 2026
a7b4a06
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Jan 26, 2026
3faeb21
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Jan 29, 2026
c9e2d29
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Jan 30, 2026
9c4f672
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Feb 17, 2026
462e8e7
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Feb 17, 2026
1b45bb5
Merge branch 'main' into kpo_deferred_logging_callback
johnhoran Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
# under the License.
from __future__ import annotations

import asyncio
import logging
import secrets
import string
from functools import cache
from functools import cache, wraps
from typing import TYPE_CHECKING

import pendulum
Expand All @@ -32,6 +33,7 @@

from airflow.configuration import conf
from airflow.providers.cncf.kubernetes.backcompat import get_logical_date_key
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode
from airflow.providers.common.compat.sdk import AirflowException

if TYPE_CHECKING:
Expand Down Expand Up @@ -211,3 +213,15 @@ def annotations_for_logging_task_metadata(annotation_set):
else:
annotations_for_logging = "<omitted>"
return annotations_for_logging


def serializable_callback(f):
"""Convert async callback so it can run in sync or async mode."""

@wraps(f)
def wrapper(*args, mode: str, **kwargs):
if mode == ExecutionMode.ASYNC:
return f(*args, mode=mode, **kwargs)
return asyncio.run(f(*args, mode=mode, **kwargs))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you confident that def serialiable_callback will not be called from within a running event loop? Since this is a helper we cannot guarantee that it will not be invoked in a context without a pre-existing event loop and this could result ina RuntimeError. I believe it would be better to not special-case ASYNC mode and let the wrapper return the function as is regardless of ExecutionMode without invoking asyncio.run(). Unless you have a strong reason to do it this way.

Copy link
Copy Markdown
Contributor Author

@johnhoran johnhoran Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the idea was to define a helper function that could be used in a callback like

class AsyncCallback(KubernetesPodOperatorCallback):
    @staticmethod
    @serializable_callback
    async def progress_callback(
        *, line: str, **kwargs
    ) -> None:
       ...

and that the same callback could be used in the triggerer or in the operator, which is invoked when the triggerer hands back to the operator if there are remaining logs in the pod that haven't been processed. Ideally then the callback should be written in async format, as it would be blocking in the callback if it wasn't, though that obviously depends on what the callback is doing.
Invocations of these callbacks are in the operator/triggerer, I don't know of anywhere else they are used.

Copy link
Copy Markdown
Contributor

@SameerMesiah97 SameerMesiah97 Jan 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the invocation of @serializable_callback is not being guarded and you are relying on the user correctly inferring the intent of the function? I agree with your motivation for introducing this helper but we cannot guarantee that it will not be called within an event loop in the operator. Is there any reason why the operator cannot do something like this:

asyncio.run(callback(...))

If that is not feasible, I believe at the very minimum:

  1. The possibility of encountering RuntimeError should be documented very clearly (docstring and comment)
  2. The RuntimeError should be caught in a try/except block with a more informative error message.
    Below is a suggested implementation if you are still intent on keeping 2 separate modes:
def serializable_callback(f):
    """
    Convert async callback so it can run in sync or async mode.

    In ASYNC mode (e.g. triggerer), the callback is expected to be awaited
    by the caller. In SYNC mode (e.g. operator fallback), the callback is
    executed via asyncio.run(); callers should ensure this is only used
    when no event loop is already running.
    """

    @wraps(f)
    def wrapper(*args, mode: str, **kwargs):
        if mode == ExecutionMode.ASYNC:
            return f(*args, mode=mode, **kwargs)

        # SYNC mode owns the event loop; calling this while a loop is already
        # running is a hard error and indicates a misclassified execution context.
        try:
            return asyncio.run(f(*args, mode=mode, **kwargs))
        except RuntimeError as e:
            raise RuntimeError(
                "Cannot call serializable_callback in SYNC mode while an event "
                "loop is running. Use ExecutionMode.ASYNC and await the callback "
                "instead."
            ) from e

    return wrapper

This will immediately inform the user of the reason for the RuntimeError and mitigate against further unsafe usage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I really understand the objection here. The invocations of this would be from the kubernetes pod operator, and sure I can't stop somebody from calling it from somewhere else, but I mean you could say that about anything...

Anyway to be honest I'd be just as happy to drop this helper from the PR.

return wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,13 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
pod = self.hook.get_pod(pod_name, pod_namespace)
if not pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")
self._write_logs(pod)
self.pod_manager.fetch_requested_container_logs(
pod=pod,
containers=self.container_logs,
container_name_log_prefix_enabled=self.container_name_log_prefix_enabled,
log_formatter=self.log_formatter,
post_termination_timeout=900,
)

if self.do_xcom_push:
xcom_results: list[Any | None] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import inspect
import json
import logging
import math
import os
import re
import shlex
Expand All @@ -39,7 +38,6 @@
from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.client.exceptions import ApiException
from kubernetes.stream import stream
from urllib3.exceptions import HTTPError

from airflow.configuration import conf
from airflow.providers.cncf.kubernetes import pod_generator
Expand Down Expand Up @@ -79,7 +77,7 @@
PodPhase,
)
from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, AirflowSkipException, TaskDeferred
from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, AirflowSkipException

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseHook, BaseOperator
Expand Down Expand Up @@ -912,6 +910,7 @@ def invoke_defer_method(
last_log_time=last_log_time,
logging_interval=self.logging_interval,
trigger_kwargs=self.trigger_kwargs,
callbacks=self.callbacks,
)
container_state = trigger.define_container_state(self.pod) if self.pod else None
if context and (
Expand Down Expand Up @@ -955,12 +954,17 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

follow = self.logging_interval is None
last_log_time = event.get("last_log_time")

if event["status"] in ("error", "failed", "timeout", "success"):
if self.get_logs:
self._write_logs(self.pod, follow=follow, since_time=last_log_time)
self.pod_manager.fetch_requested_container_logs(
pod=self.pod,
containers=self.container_logs,
container_name_log_prefix_enabled=self.container_name_log_prefix_enabled,
log_formatter=self.log_formatter,
since_time=last_log_time,
post_termination_timeout=900,
)

for callback in self.callbacks:
callback.on_pod_completion(
Expand All @@ -987,8 +991,6 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
)
message = event.get("stack_trace", event["message"])
raise AirflowException(message)
except TaskDeferred:
raise
finally:
self._clean(event=event, context=context, result=xcom_sidecar_output)

Expand Down Expand Up @@ -1023,33 +1025,6 @@ def _clean(self, event: dict[str, Any], result: dict | None, context: Context) -
result=result,
)

def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
try:
since_seconds = (
math.ceil((datetime.datetime.now(tz=datetime.timezone.utc) - since_time).total_seconds())
if since_time
else None
)
logs = self.client.read_namespaced_pod_log(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
container=self.base_container_name,
follow=follow,
timestamps=False,
since_seconds=since_seconds,
_preload_content=False,
)
for raw_line in logs:
line = raw_line.decode("utf-8", errors="backslashreplace").rstrip("\n")
if line:
self.log.info("[%s] logs: %s", self.base_container_name, line)
except (HTTPError, ApiException) as e:
self.log.warning(
"Reading of logs interrupted with error %r; will retry. "
"Set log level to DEBUG for traceback.",
e if not isinstance(e, ApiException) else e.reason,
)

def post_complete_action(
self, *, pod: k8s.V1Pod, remote_pod: k8s.V1Pod, context: Context, result: dict | None, **kwargs
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio
import datetime
import importlib
import traceback
from collections.abc import AsyncIterator
from enum import Enum
Expand All @@ -40,6 +41,8 @@
from kubernetes_asyncio.client.models import V1Pod
from pendulum import DateTime

from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback


class ContainerState(str, Enum):
"""
Expand Down Expand Up @@ -101,6 +104,7 @@ def __init__(
last_log_time: DateTime | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
callbacks: list[type[KubernetesPodOperatorCallback]] | str | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -123,6 +127,18 @@ def __init__(
self.trigger_kwargs = trigger_kwargs or {}
self._since_time = None

if callbacks and isinstance(callbacks, str):
self._callbacks = []
for cbk in callbacks.split(","):
try:
module_name, class_name = cbk.rsplit(".", 1)
clazz = getattr(importlib.import_module(module_name), class_name)
self._callbacks.append(clazz)
except (AttributeError, ModuleNotFoundError, ValueError) as e:
self.log.warning("Failed to import callback %s: %s", cbk, e)
else:
self._callbacks = callbacks or []

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreatePodTrigger arguments and classpath."""
return (
Expand All @@ -146,6 +162,12 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"last_log_time": self.last_log_time,
"logging_interval": self.logging_interval,
"trigger_kwargs": self.trigger_kwargs,
"callbacks": ",".join(
[
f"{x.__module__.split('_', 3)[3] if x.__module__.startswith('unusual_prefix_') else x.__module__}.{x.__name__}"
for x in self._callbacks
]
),
},
)

Expand All @@ -157,6 +179,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.pod_namespace,
self.poll_interval,
)

try:
state = await self._wait_for_pod_start()
if state == ContainerState.TERMINATED:
Expand Down Expand Up @@ -332,7 +355,7 @@ def hook(self) -> AsyncKubernetesHook:

@cached_property
def pod_manager(self) -> AsyncPodManager:
return AsyncPodManager(async_hook=self.hook)
return AsyncPodManager(async_hook=self.hook, callbacks=self._callbacks)

def define_container_state(self, pod: V1Pod) -> ContainerState:
if pod.status is None or pod.status.container_statuses is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Literal, cast

import kubernetes_asyncio.client as async_k8s
import pendulum
from kubernetes import client, watch
from kubernetes.client.rest import ApiException
Expand Down Expand Up @@ -680,6 +681,8 @@ def fetch_requested_container_logs(
follow_logs=False,
container_name_log_prefix_enabled: bool = True,
log_formatter: Callable[[str, str], str] | None = None,
since_time: DateTime | None = None,
post_termination_timeout: int = 120,
) -> list[PodLoggingStatus]:
"""
Follow the logs of containers in the specified pod and publish it to airflow logging.
Expand All @@ -702,6 +705,8 @@ def fetch_requested_container_logs(
follow=follow_logs,
container_name_log_prefix_enabled=container_name_log_prefix_enabled,
log_formatter=log_formatter,
since_time=since_time,
post_termination_timeout=post_termination_timeout,
)
pod_logging_statuses.append(status)
return pod_logging_statuses
Expand Down Expand Up @@ -1110,31 +1115,57 @@ async def fetch_container_logs_before_current_sec(
since_seconds=(math.ceil((now - since_time).total_seconds()) if since_time else None),
)
message_to_log = None
try:
now_seconds = now.replace(microsecond=0)
for line in logs:
line_timestamp, message = parse_log_line(line)
# Skip log lines from the current second to prevent duplicate entries on the next read.
# The API only allows specifying 'since_seconds', not an exact timestamp.
if line_timestamp and line_timestamp.replace(microsecond=0) == now_seconds:
break
if line_timestamp: # detect new log line
if message_to_log is None: # first line in the log
message_to_log = message
else: # previous log line is complete
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
self.log.info("[%s] %s", container_name, message_to_log)
message_to_log = message
elif message_to_log: # continuation of the previous log line
message_to_log = f"{message_to_log}\n{message}"
finally:
# log the last line and update the last_captured_timestamp
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
self.log.info("[%s] %s", container_name, message_to_log)
async with self._hook.get_conn() as connection:
v1_api = async_k8s.CoreV1Api(connection)
try:
now_seconds = now.replace(microsecond=0)
for line in logs:
line_timestamp, message = parse_log_line(line)
# Skip log lines from the current second to prevent duplicate entries on the next read.
# The API only allows specifying 'since_seconds', not an exact timestamp.
if line_timestamp and line_timestamp.replace(microsecond=0) == now_seconds:
break
if line_timestamp: # detect new log line
if message_to_log is None: # first line in the log
message_to_log = message
else: # previous log line is complete
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
for callback in self._callbacks:
cb = callback.progress_callback(
line=message_to_log,
client=v1_api,
mode=ExecutionMode.ASYNC,
container_name=container_name,
timestamp=line_timestamp,
pod=pod,
)
if asyncio.iscoroutine(cb):
await cb
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we really be executing arbitrary user code inside the triggerer? Even though this is async, a long-running or blocking progress callback (for example calling an external API without a timeout) can still starve the triggerer’s event loop. That at least blocks the trigger executing it, and potentially other triggers handled by the same triggerer process. This feels like a fairly big design trade-off just to support progress logging, and I’m not sure it’s worth it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The specific use case I have in mind is in support of the watcher pattern that is being implement in astronomer cosmos for running DBT. astronomer/astronomer-cosmos#2207

With this you have a pod that's actually running DBT and an airflow task that is parsing kubernetes logs to extract DBT events and create xcom variables that are consumed by sensors. I think the parsing of the logs and setting xcom variables should be lightweight enough that it can be run from the triggerer. Implementing this from the triggerer was part of the original design of the defferred mode for KPO, but the implementation ran into issues so it was stripped back out.

I did wonder if I should check the callbacks and only pass the ones that have actually implemented progress_callbacks to make everything a little lighter.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key phrase here is “should be lightweight”. Can we really guarantee that? Reading the code, it doesn’t seem like there are any restrictions on what a callback could execute, so we are effectively trusting users not to do any heavy lifting in them.

One possible compromise would be to enforce a small global timeout for callbacks (e.g. a few seconds at most). However, even with that, this still sets a precedent for executing arbitrary user code in the triggerer.

I agree the motivation here is solid and I can see the value of the feature, but this crosses into a fundamental triggerer design decision, which I’m not comfortable approving or disapproving unilaterally.

@jscheffl I know you requested my review — I’d be interested in your thoughts on whether this is a precedent we’re happy to set for triggerers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There aren't any restrictions on what can be run in the triggerer either though. You can add a custom operator and run any arbitrary code you want there.

The larger concern I have is that somebody might write a callback using the non deferred mode and then switch to running in deferred mode, and then you have the triggerer calling synchronous code it wasn't designed for. I don't really have an answer to that, beyond noting that the progress_callback was broken from first implementation until very recently, so I guess nobody has been really using it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have never worked with progress callbacks... do we call them on EACH log line? That smells like a massive overhead. So once people start using this I see a massive performance problem coming or you can not scale.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah its called per line. I guess for the use case I have in mind, the log output isn't particularly verbose, but I can imagine situations where it would be an issue.


self.log.info("[%s] %s", container_name, message_to_log)
message_to_log = message
elif message_to_log: # continuation of the previous log line
message_to_log = f"{message_to_log}\n{message}"
finally:
# log the last line and update the last_captured_timestamp
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
for callback in self._callbacks:
cb = callback.progress_callback(
line=message_to_log,
client=v1_api,
mode=ExecutionMode.ASYNC,
container_name=container_name,
timestamp=line_timestamp,
pod=pod,
)
if asyncio.iscoroutine(cb):
await cb

self.log.info("[%s] %s", container_name, message_to_log)
return now # Return the current time as the last log time to ensure logs from the current second are read in the next fetch.
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ def test_wait_until_job_complete(

@pytest.mark.parametrize("do_xcom_push", [True, False])
@pytest.mark.parametrize("get_logs", [True, False])
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._write_logs"))
def test_execute_complete(self, mocked_write_logs, get_logs, do_xcom_push):
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"))
def test_execute_complete(self, mock_manager, get_logs, do_xcom_push):
mock_ti = mock.MagicMock()
context = {"ti": mock_ti}
mock_job = mock.MagicMock()
Expand All @@ -839,9 +839,9 @@ def test_execute_complete(self, mocked_write_logs, get_logs, do_xcom_push):
mock_ti.xcom_push.assert_called_once_with(key="job", value=mock_job)

if get_logs:
mocked_write_logs.assert_called_once()
mock_manager.fetch_requested_container_logs.assert_called_once()
else:
mocked_write_logs.assert_not_called()
mock_manager.fetch_requested_container_logs.assert_not_called()

@pytest.mark.non_db_test_override
def test_execute_complete_fail(self):
Expand Down
Loading
Loading