diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 27c9439dbde1d..9b106d28b5414 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -17,11 +17,13 @@ """Launches PODs""" import json import math +import multiprocessing import time import warnings from contextlib import closing from dataclasses import dataclass from datetime import datetime +from multiprocessing.sharedctypes import RawValue from typing import TYPE_CHECKING, Iterable, Optional, Tuple, cast import pendulum @@ -193,6 +195,40 @@ def follow_container_logs(self, pod: V1Pod, container_name: str) -> PodLoggingSt ) return self.fetch_container_logs(pod=pod, container_name=container_name, follow=True) + def log_iterable(self, logs: Iterable[bytes]) -> Optional[DateTime]: + timestamp = None + for line in logs: + timestamp, message = self.parse_log_line(line.decode('utf-8', errors="backslashreplace")) + self.log.info(message) + return timestamp + + def consume_container_logs_stream( + self, pod: V1Pod, container_name: str, stream: Iterable[bytes] + ) -> Optional[DateTime]: + def log_iterable_and_set_value(timestamp): + dt = self.log_iterable(stream) + if dt is not None: + timestamp.value = dt.timestamp() # type: ignore[attr-defined] + + timestamp = RawValue('f') # read and write are synchronous so rawvalue is enough + p = multiprocessing.Process(target=log_iterable_and_set_value, args=(timestamp,)) + p.start() + self.await_container_completion(pod, container_name) + p.join(timeout=5) + if p.is_alive(): + p.terminate() + p.join() + self.log.warning( + "Container %s log read was interrupted at some point caused by log rotation " + "see https://github.com/apache/airflow/issues/23497 for reference.", + container_name, + ) + return None + p.close() + if not timestamp.value: + return None + return pendulum.from_timestamp(timestamp.value) + def fetch_container_logs( self, pod: V1Pod, container_name: str, *, follow=False, since_time: Optional[DateTime] = None ) -> PodLoggingStatus: @@ -220,10 +256,11 @@ def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True) ), follow=follow, ) - for raw_line in logs: - line = raw_line.decode('utf-8', errors="backslashreplace") - timestamp, message = self.parse_log_line(line) - self.log.info(message) + if follow: + timestamp = self.consume_container_logs_stream(pod, container_name, logs) + else: + timestamp = self.log_iterable(logs) + except BaseHTTPError as e: self.log.warning( "Reading of logs interrupted with error %r; will retry. " diff --git a/kubernetes_tests/kubernetes_test_utils.py b/kubernetes_tests/kubernetes_test_utils.py new file mode 100644 index 0000000000000..d851b50038618 --- /dev/null +++ b/kubernetes_tests/kubernetes_test_utils.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import multiprocessing + + +class SharedLogger: + def __init__(self, str_to_count: str, level: str): + self.counter = multiprocessing.Value('i') + self.str_to_count = str_to_count + self.level = level + + def info(self, message, *args): + self._count(message, "info") + + def warning(self, message, *args): + self._count(message, "warning") + + def debug(self, message, *args): + self._count(message, "debug") + + def error(self, message, *args): + self._count(message, "error") + + def _count(self, message, level): + if level != self.level: + return + if message == self.str_to_count: + self.counter.value += 1 diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 49928274517ac..14f6ab7c88f3a 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -42,6 +42,7 @@ from airflow.utils import timezone from airflow.utils.types import DagRunType from airflow.version import version as airflow_version +from kubernetes_tests.kubernetes_test_utils import SharedLogger HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook" POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" @@ -430,7 +431,9 @@ def test_port(self): assert self.expected_pod == actual_pod def test_volume_mount(self): - with mock.patch.object(PodManager, 'log') as mock_logger: + with mock.patch.object( + PodManager, 'log', new=SharedLogger("retrieved from mount", "info") + ) as mock_logger: volume_mount = k8s.V1VolumeMount( name='test-volume', mount_path='/tmp/test_volume', sub_path=None, read_only=False ) @@ -459,7 +462,9 @@ def test_volume_mount(self): ) context = create_context(k) k.execute(context=context) - mock_logger.info.assert_any_call('retrieved from mount') + + assert mock_logger.counter.value == 1 + actual_pod = self.api_client.sanitize_for_serialization(k.pod) self.expected_pod['spec']['containers'][0]['args'] = args self.expected_pod['spec']['containers'][0]['volumeMounts'] = [ diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py index 5a4efc73d4383..5cf72abcda4df 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py +++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py @@ -42,6 +42,7 @@ from airflow.utils import timezone from airflow.utils.types import DagRunType from airflow.version import version as airflow_version +from kubernetes_tests.kubernetes_test_utils import SharedLogger # noinspection DuplicatedCode @@ -276,7 +277,7 @@ def test_port(self): assert self.expected_pod == actual_pod def test_volume_mount(self): - with patch.object(PodManager, 'log') as mock_logger: + with patch.object(PodManager, 'log', new=SharedLogger("retrieved from mount", "info")) as mock_logger: volume_mount = VolumeMount( 'test-volume', mount_path='/tmp/test_volume', sub_path=None, read_only=False ) @@ -303,7 +304,7 @@ def test_volume_mount(self): ) context = create_context(k) k.execute(context=context) - mock_logger.info.assert_any_call('retrieved from mount') + assert mock_logger.counter.value == 1 actual_pod = self.api_client.sanitize_for_serialization(k.pod) expected_pod = copy(self.expected_pod) expected_pod['spec']['containers'][0]['args'] = args diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py index 8070c3c3532b5..3976bcb2bf794 100644 --- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py +++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import logging +import time +from typing import Generator from unittest import mock from unittest.mock import MagicMock @@ -322,13 +324,39 @@ def test_fetch_container_running_follow( When called with follow=False, should return immediately even though still running. """ mock_pod = MagicMock() - container_running_mock.side_effect = [True, True, False] # only will be called once + container_running_mock.side_effect = [True, False, False] # called once when follow=False self.mock_kube_client.read_namespaced_pod_log.return_value = [b'2021-01-01 hi'] ret = self.pod_manager.fetch_container_logs(pod=mock_pod, container_name='base', follow=follow) assert len(container_running_mock.call_args_list) == is_running_calls assert ret.last_log_time == DateTime(2021, 1, 1, tzinfo=Timezone('UTC')) assert ret.running is exp_running + @pytest.mark.parametrize( + 'follow, is_running_calls, exp_running, producing_logs', + [(True, 3, False, False), (True, 3, False, True)], + ) + @mock.patch('airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running') + def test_fetch_container_running_follow_when_kube_api_hangs( + self, container_running_mock, follow, is_running_calls, exp_running, producing_logs + ): + """ + When called with follow, should keep looping even after disconnections, if pod still running. + """ + mock_pod = MagicMock() + container_running_mock.side_effect = [True, False, False] + + def stream_logs() -> Generator: + while True: + time.sleep(1) # this is intentional: urllib3.response.stream() is not async + if producing_logs: + yield b'2021-01-01 hi' + + self.mock_kube_client.read_namespaced_pod_log.return_value = stream_logs() + ret = self.pod_manager.fetch_container_logs(pod=mock_pod, container_name='base', follow=follow) + assert len(container_running_mock.call_args_list) == is_running_calls + assert ret.running is exp_running + assert ret.last_log_time is None + def params_for_test_container_is_running(): """The `container_is_running` method is designed to handle an assortment of bad objects