diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 8d3d905be9552..4bd0de7e29e4a 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -212,7 +212,7 @@ def process_status( self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version)) elif status == "Succeeded": self.log.info("Event: %s Succeeded", pod_id) - self.watcher_queue.put((pod_id, namespace, None, annotations, resource_version)) + self.watcher_queue.put((pod_id, namespace, State.SUCCESS, annotations, resource_version)) elif status == "Running": if event["type"] == "DELETED": self.log.info("Event: Pod %s deleted before it could complete", pod_id) @@ -751,19 +751,26 @@ def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, na if TYPE_CHECKING: assert self.kube_scheduler - if state != State.RUNNING: - if self.kube_config.delete_worker_pods: - if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure: - self.kube_scheduler.delete_pod(pod_id, namespace) - self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace)) - else: - self.kube_scheduler.patch_pod_executor_done(pod_id=pod_id, namespace=namespace) - self.log.info("Patched pod %s in namespace %s to mark it as done", str(key), str(namespace)) - try: - self.running.remove(key) - except KeyError: - self.log.debug("Could not find key: %s", str(key)) - self.event_buffer[key] = state, None + if state == State.RUNNING: + self.event_buffer[key] = state, None + return + + if self.kube_config.delete_worker_pods: + if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure: + self.kube_scheduler.delete_pod(pod_id, namespace) + self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace)) + else: + self.kube_scheduler.patch_pod_executor_done(pod_id=pod_id, namespace=namespace) + self.log.info("Patched pod %s in namespace %s to mark it as done", str(key), str(namespace)) + + try: + self.running.remove(key) + except KeyError: + self.log.debug("TI key not in running, not adding to event_buffer: %s", key) + else: + # We get multiple events once the pod hits a terminal state, and we only want to + # do this once, so only do it when we remove the task from running + self.event_buffer[key] = state, None def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id] @@ -841,6 +848,8 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: ) except ApiException as e: self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e) + pod_id = annotations_to_key(pod.metadata.annotations) + self.running.add(pod_id) def _flush_task_queue(self) -> None: if TYPE_CHECKING: diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 8e30835aa3e73..72052622a7d98 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -519,8 +519,10 @@ def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_wa executor.start() try: key = ("dag_id", "task_id", "run_id", "try_number1") + executor.running = {key} executor._change_state(key, State.RUNNING, "pod_id", "default") assert executor.event_buffer[key][0] == State.RUNNING + assert executor.running == {key} finally: executor.end() @@ -532,8 +534,10 @@ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_ executor.start() try: key = ("dag_id", "task_id", "run_id", "try_number2") + executor.running = {key} executor._change_state(key, State.SUCCESS, "pod_id", "default") assert executor.event_buffer[key][0] == State.SUCCESS + assert executor.running == set() mock_delete_pod.assert_called_once_with("pod_id", "default") finally: executor.end() @@ -550,8 +554,10 @@ def test_change_state_failed_no_deletion( executor.start() try: key = ("dag_id", "task_id", "run_id", "try_number3") + executor.running = {key} executor._change_state(key, State.FAILED, "pod_id", "default") assert executor.event_buffer[key][0] == State.FAILED + assert executor.running == set() mock_delete_pod.assert_not_called() finally: executor.end() @@ -594,8 +600,10 @@ def test_change_state_skip_pod_deletion( executor.start() try: key = ("dag_id", "task_id", "run_id", "try_number2") + executor.running = {key} executor._change_state(key, State.SUCCESS, "pod_id", "test-namespace") assert executor.event_buffer[key][0] == State.SUCCESS + assert executor.running == set() mock_delete_pod.assert_not_called() mock_patch_pod.assert_called_once_with(pod_id="pod_id", namespace="test-namespace") finally: @@ -615,8 +623,10 @@ def test_change_state_failed_pod_deletion( executor.start() try: key = ("dag_id", "task_id", "run_id", "try_number2") + executor.running = {key} executor._change_state(key, State.FAILED, "pod_id", "test-namespace") assert executor.event_buffer[key][0] == State.FAILED + assert executor.running == set() mock_delete_pod.assert_called_once_with("pod_id", "test-namespace") mock_patch_pod.assert_not_called() finally: @@ -765,17 +775,27 @@ def test_adopt_completed_pods(self, mock_kube_client): executor.kube_client = mock_kube_client executor.kube_config.kube_namespace = "somens" pod_names = ["one", "two"] + + def get_annotations(pod_name): + return { + "dag_id": "dag", + "run_id": "run_id", + "task_id": pod_name, + "try_number": "1", + } + mock_kube_client.list_namespaced_pod.return_value.items = [ k8s.V1Pod( metadata=k8s.V1ObjectMeta( name=pod_name, labels={"airflow-worker": pod_name}, - annotations={"some_annotation": "hello"}, + annotations=get_annotations(pod_name), namespace="somens", ) ) for pod_name in pod_names ] + expected_running_ti_keys = {annotations_to_key(get_annotations(pod_name)) for pod_name in pod_names} executor._adopt_completed_pods(mock_kube_client) mock_kube_client.list_namespaced_pod.assert_called_once_with( @@ -795,6 +815,7 @@ def test_adopt_completed_pods(self, mock_kube_client): ], any_order=True, ) + assert executor.running == expected_running_ti_keys @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") def test_not_adopt_unassigned_task(self, mock_kube_client): @@ -1265,7 +1286,7 @@ def test_process_status_succeeded(self): self.events.append({"type": "MODIFIED", "object": self.pod}) self._run() - self.assert_watcher_queue_called_once_with_state(None) + self.assert_watcher_queue_called_once_with_state(State.SUCCESS) def test_process_status_running_deleted(self): self.pod.status.phase = "Running"