Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b841b49
Create a generic callbacks class for KubernetesPodOperator
hussein-awala Nov 17, 2023
dde518d
Merge branch 'main' into kpo_callbacks
hussein-awala Nov 28, 2023
2797aef
Trigger tests with old-style union
hussein-awala Nov 28, 2023
6b884f0
Fix GCP K8S test
hussein-awala Nov 28, 2023
3ddfd82
Fix callback param type and cleanup calls
hussein-awala Nov 28, 2023
d245298
Some fixes and add unit tests
hussein-awala Nov 28, 2023
f722e40
Replace type by Type
hussein-awala Nov 28, 2023
c4ae081
Reset mock_callbacks in pod manager tests
hussein-awala Nov 28, 2023
c13b95e
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 9, 2024
6173c1b
Fix static checks
hussein-awala Jan 9, 2024
b1755af
Add a doc paragraph for the new callbacks
hussein-awala Jan 10, 2024
1860881
Add a check for cncf-kuberntes version in google provider
hussein-awala Jan 10, 2024
479b3c8
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 11, 2024
d9fd475
Switch to None default value and bump min cncf-k8s provider in google…
hussein-awala Jan 12, 2024
255b060
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 12, 2024
39ece30
Fix tests
hussein-awala Jan 13, 2024
1d133d7
Reduce check intervals to avoid killing the asyncio task
hussein-awala Jan 13, 2024
be57a68
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 19, 2024
79354e8
Revert async callbacks
hussein-awala Jan 20, 2024
f56a925
fix breeze tests
hussein-awala Jan 20, 2024
3470855
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions airflow/providers/cncf/kubernetes/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from enum import Enum
from typing import Union

import kubernetes.client as k8s
import kubernetes_asyncio.client as async_k8s

# Type of the ``client`` argument handed to the pod callbacks: either the sync
# ``kubernetes.client.CoreV1Api`` or the async ``kubernetes_asyncio.client.CoreV1Api``.
client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]


class ExecutionMode(str, Enum):
    """Execution mode reported to callbacks through their ``mode`` argument.

    Members mix in ``str``, so each one compares equal to its raw value
    (``"sync"`` / ``"async"``).
    """

    # Synchronous execution mode.
    SYNC = "sync"
    # Asynchronous execution mode.
    ASYNC = "async"


class KubernetesPodOperatorCallback:
    """Base class for `KubernetesPodOperator` lifecycle callbacks.

    Every hook is a static method with a no-op default implementation, so the class itself
    (not an instance) can be used as the callbacks holder. Subclass it and override only the
    hooks you need. All hooks take keyword-only arguments and accept extra ``**kwargs`` so that
    new arguments can be added later without breaking existing subclasses.
    """

    @staticmethod
    def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
        """Invoke this callback after creating the sync client.

        :param client: the created `kubernetes.client.CoreV1Api` client.
        """

    @staticmethod
    def on_async_client_creation(*, client: async_k8s.CoreV1Api, **kwargs) -> None:
        """Invoke this callback after creating the async client.

        :param client: the created `kubernetes_asyncio.client.CoreV1Api` client.
        """

    @staticmethod
    def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Invoke this callback after creating the pod.

        :param pod: the created pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Invoke this callback when the pod starts.

        :param pod: the started pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Invoke this callback when the pod completes.

        :param pod: the completed pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Invoke this callback after cleaning/deleting the pod.

        :param pod: the completed pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_operator_resuming(
        *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs
    ) -> None:
        """Invoke this callback when resuming the `KubernetesPodOperator` from deferred state.

        :param pod: the current state of the pod.
        :param event: the returned event from the Trigger.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def progress_callback(*, line: str, client: client_type, mode: str, **kwargs) -> None:
        """Invoke this callback to process pod container logs.

        :param line: the read line of log.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """
48 changes: 45 additions & 3 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
convert_volume,
convert_volume_mount,
)
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
POD_NAME_MAX_LENGTH,
Expand Down Expand Up @@ -197,7 +198,10 @@ class KubernetesPodOperator(BaseOperator):
Default value is "File"
:param active_deadline_seconds: The active_deadline_seconds which translates to active_deadline_seconds
in V1PodSpec.
:param callbacks: KubernetesPodOperatorCallback class (or a subclass of it) whose callback methods are
called at the different steps of the KubernetesPodOperator execution.
:param progress_callback: Callback function for receiving k8s container logs.
`progress_callback` is deprecated, please use the `callbacks` parameter instead.
"""

# !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
Expand Down Expand Up @@ -289,6 +293,7 @@ def __init__(
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -380,6 +385,7 @@ def __init__(

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
self.callbacks = callbacks

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -457,7 +463,9 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool

@cached_property
def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
return PodManager(
kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback
)

@cached_property
def hook(self) -> PodOperatorHookProtocol:
Expand All @@ -471,7 +479,10 @@ def hook(self) -> PodOperatorHookProtocol:

@cached_property
def client(self) -> CoreV1Api:
return self.hook.core_v1_client
client = self.hook.core_v1_client
if self.callbacks:
self.callbacks.on_sync_client_creation(client=client)
return client

def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
"""Return an already-running pod for this task instance if one exists."""
Expand Down Expand Up @@ -550,7 +561,17 @@ def execute_sync(self, context: Context):

# get remote pod for use in cleanup methods
self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
)
self.await_pod_start(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_starting(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
Comment thread
potiuk marked this conversation as resolved.

if self.get_logs:
self.pod_manager.fetch_requested_container_logs(
Expand All @@ -564,6 +585,12 @@ def execute_sync(self, context: Context):
self.pod_manager.await_container_completion(
pod=self.pod, container_name=self.base_container_name
)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)

if self.do_xcom_push:
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
Expand All @@ -573,10 +600,13 @@ def execute_sync(self, context: Context):
self.pod, istio_enabled, self.base_container_name
)
finally:
pod_to_clean = self.pod or self.pod_request_obj
self.cleanup(
pod=self.pod or self.pod_request_obj,
pod=pod_to_clean,
remote_pod=self.remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)
if self.do_xcom_push:
return result

Expand All @@ -586,6 +616,12 @@ def execute_async(self, context: Context):
pod_request_obj=self.pod_request_obj,
context=context,
)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
Expand Down Expand Up @@ -622,6 +658,10 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
event["name"],
event["namespace"],
)
if self.callbacks:
self.callbacks.on_operator_resuming(
pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when pod is failed
if self.get_logs:
Expand Down Expand Up @@ -674,6 +714,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs):
pod=pod,
remote_pod=remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC)

def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
istio_enabled = self.is_istio_enabled(remote_pod)
Expand Down
45 changes: 44 additions & 1 deletion airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(
startup_check_interval: int = 1,
on_finish_action: str = "delete_pod",
should_delete_pod: bool | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -102,6 +104,7 @@ def __init__(
self.get_logs = get_logs
self.startup_timeout = startup_timeout
self.startup_check_interval = startup_check_interval
self.callbacks = callbacks

if should_delete_pod is not None:
warnings.warn(
Expand Down Expand Up @@ -137,12 +140,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"trigger_start_time": self.trigger_start_time,
"should_delete_pod": self.should_delete_pod,
"on_finish_action": self.on_finish_action.value,
"callbacks": self.callbacks,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pod status and yield a TriggerEvent."""
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
_is_starting_callback_called = False
try:
while True:
pod = await self.hook.get_pod(
Expand All @@ -157,6 +162,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.log.debug("Container %s status: %s", self.base_container_name, container_state)

if container_state == ContainerState.TERMINATED:
if self.callbacks and not _is_starting_callback_called:
self.callbacks.on_pod_starting(
pod=pod,
client=self._get_async_hook().core_v1_client,
mode=ExecutionMode.ASYNC,
)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC
)
yield TriggerEvent(
{
"name": self.pod_name,
Expand Down Expand Up @@ -189,9 +204,28 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.log.info("Sleeping for %s seconds.", self.startup_check_interval)
await asyncio.sleep(self.startup_check_interval)
else:
if self.callbacks and not _is_starting_callback_called:
# if the trigger fails and is re-run on a different triggerer, this callback could
# be called again
self.callbacks.on_pod_starting(
pod=pod,
client=self._get_async_hook().core_v1_client,
mode=ExecutionMode.ASYNC,
)
_is_starting_callback_called = True
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
else:
if self.callbacks and not _is_starting_callback_called:
self.callbacks.on_pod_starting(
pod=pod,
client=self._get_async_hook().core_v1_client,
mode=ExecutionMode.ASYNC,
)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC
)
yield TriggerEvent(
{
"name": self.pod_name,
Expand All @@ -215,6 +249,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
name=self.pod_name,
namespace=self.pod_namespace,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(
pod=await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace),
client=self._get_async_hook().core_v1_client,
mode=ExecutionMode.ASYNC,
)
yield TriggerEvent(
{
"name": self.pod_name,
Expand All @@ -237,12 +277,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]

def _get_async_hook(self) -> AsyncKubernetesHook:
# TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider.
return AsyncKubernetesHook(
_hook = AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
if self.callbacks:
self.callbacks.on_async_client_creation(client=_hook.core_v1_client)
return _hook

@cached_property
def hook(self) -> AsyncKubernetesHook:
Expand Down
Loading