-
Notifications
You must be signed in to change notification settings - Fork 14.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AIP 72: Handling deferrable tasks in execution API as well as TASK SDK #44241
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
from __future__ import annotations | ||
|
||
import uuid | ||
from typing import Annotated, Literal, Union | ||
from typing import Annotated, Any, Literal, Union | ||
|
||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema | ||
|
||
|
@@ -61,6 +61,30 @@ class TITargetStatePayload(BaseModel): | |
state: IntermediateTIState | ||
|
||
|
||
class TIDeferredStatePayload(BaseModel):
    """Request body used to move a TaskInstance into the deferred state."""

    # The JSON schema advertises a default, but no Python-level default is given,
    # so Pydantic still treats ``state`` as a required field.
    state: Annotated[
        Literal[IntermediateTIState.DEFERRED],
        WithJsonSchema(
            {
                "type": "string",
                "enum": [IntermediateTIState.DEFERRED],
                "default": IntermediateTIState.DEFERRED,
            }
        ),
    ]

    classpath: str
    kwargs: dict[str, Any]
    created_date: UtcDateTime
    next_method: str
    # TODO: this should be serialised as a datetime.timedelta rather than a plain string.
    timeout: str | None
    # NOTE(review): no trigger id field here; per review it is set in the DB once the
    # API processes the request, so it cannot be part of the request body.
|
||
|
||
def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: | ||
""" | ||
Determine the discriminator key for TaskInstance state transitions. | ||
|
@@ -78,6 +102,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: | |
return str(state) | ||
elif state in set(TerminalTIState): | ||
return "_terminal_" | ||
elif state == "deferred": | ||
return "deferred" | ||
return "_other_" | ||
|
||
|
||
|
@@ -88,6 +114,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: | |
Annotated[TIEnterRunningPayload, Tag("running")], | ||
Annotated[TITerminalStatePayload, Tag("_terminal_")], | ||
Annotated[TITargetStatePayload, Tag("_other_")], | ||
Annotated[TIDeferredStatePayload, Tag("deferred")], | ||
], | ||
Discriminator(ti_state_discriminator), | ||
] | ||
|
@@ -105,7 +132,7 @@ class TIHeartbeatInfo(BaseModel): | |
class TaskInstance(BaseModel): | ||
"""Schema for TaskInstance model with minimal required fields needed for Runtime.""" | ||
|
||
id: uuid.UUID | ||
id: uuid.UUID | int = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This can't be 0 |
||
|
||
task_id: str | ||
dag_id: str | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,8 @@ | |
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Annotated | ||
from datetime import timedelta | ||
from typing import TYPE_CHECKING, Annotated | ||
from uuid import UUID | ||
|
||
from fastapi import Body, Depends, HTTPException, status | ||
|
@@ -30,14 +31,16 @@ | |
from airflow.api_fastapi.common.db.common import get_session | ||
from airflow.api_fastapi.common.router import AirflowRouter | ||
from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( | ||
TIDeferredStatePayload, | ||
TIEnterRunningPayload, | ||
TIHeartbeatInfo, | ||
TIStateUpdate, | ||
TITerminalStatePayload, | ||
) | ||
from airflow.models import Trigger | ||
from airflow.models.taskinstance import TaskInstance as TI | ||
from airflow.utils import timezone | ||
from airflow.utils.state import State | ||
from airflow.utils.state import State, TaskInstanceState | ||
|
||
# TODO: Add dependency on JWT token | ||
router = AirflowRouter() | ||
|
@@ -122,6 +125,54 @@ def ti_update_state( | |
) | ||
elif isinstance(ti_patch_payload, TITerminalStatePayload): | ||
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) | ||
elif isinstance(ti_patch_payload, TIDeferredStatePayload): | ||
trigger_row = Trigger( | ||
classpath=ti_patch_payload.classpath, | ||
kwargs=ti_patch_payload.kwargs, | ||
created_date=ti_patch_payload.created_date, | ||
) | ||
session.add(trigger_row) | ||
session.flush() | ||
|
||
ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none() | ||
|
||
if not ti: | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail={ | ||
"message": f"TaskInstance with id {ti_id_str} not found.", | ||
}, | ||
) | ||
|
||
ti.state = TaskInstanceState.DEFERRED | ||
ti.trigger_id = trigger_row.id | ||
ti.next_method = ti_patch_payload.next_method | ||
ti.next_kwargs = ti_patch_payload.kwargs or {} | ||
# handle properly based on client | ||
timeout = ti_patch_payload.timeout | ||
# Calculate timeout too if it was passed | ||
if timeout is not None: | ||
ti.trigger_timeout = timezone.utcnow() + timedelta(days=int(timeout)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Timeout shouldn't be days! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I need to handle it better based on what I can send from the client; this isn't right yet. |
||
else: | ||
ti.trigger_timeout = None | ||
|
||
# If an execution_timeout is set, set the timeout to the minimum of | ||
# it and the trigger timeout | ||
if ti.task: | ||
execution_timeout = ti.task.execution_timeout | ||
if execution_timeout: | ||
if TYPE_CHECKING: | ||
assert ti.start_date | ||
if ti.trigger_timeout: | ||
ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout) | ||
else: | ||
ti.trigger_timeout = ti.start_date + execution_timeout | ||
|
||
session.merge(ti) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Merge isn't needed - the object is already attached to the session. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Commit only will do, right? |
||
session.commit() | ||
|
||
log.info("TI %s state updated to: deferred", ti_id_str) | ||
return | ||
|
||
# TODO: Replace this with FastAPI's Custom Exception handling: | ||
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,15 +17,17 @@ | |
|
||
from __future__ import annotations | ||
|
||
from datetime import datetime | ||
from unittest import mock | ||
|
||
import pytest | ||
from sqlalchemy import select | ||
from sqlalchemy.exc import SQLAlchemyError | ||
|
||
from airflow.models import Trigger | ||
from airflow.models.taskinstance import TaskInstance | ||
from airflow.utils import timezone | ||
from airflow.utils.state import State | ||
from airflow.utils.state import State, TaskInstanceState | ||
|
||
from tests_common.test_utils.db import clear_db_runs | ||
|
||
|
@@ -193,6 +195,50 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta | |
assert response.status_code == 500 | ||
assert response.json()["detail"] == "Database error occurred" | ||
|
||
def test_ti_update_state_to_deferred(self, client, session, create_task_instance): | ||
""" | ||
Test that tests if the transition to deferred state is handled correctly. | ||
""" | ||
clear_db_runs() | ||
|
||
ti = create_task_instance( | ||
task_id="test_ti_update_state_to_deferred", | ||
state=State.RUNNING, | ||
session=session, | ||
) | ||
session.commit() | ||
|
||
payload = { | ||
"state": "deferred", | ||
"classpath": "my-class-path", | ||
"kwargs": {}, | ||
"created_date": "2024-10-31T12:00:00Z", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, why is this being sent in the payload? Should it be the time the server received the request instead? |
||
"next_method": "execute_callback", | ||
"timeout": None, | ||
} | ||
|
||
response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) | ||
|
||
assert response.status_code == 204 | ||
assert response.text == "" | ||
|
||
session.expire_all() | ||
|
||
t = session.query(Trigger).all() | ||
assert len(t) == 1 | ||
assert t[0].created_date == datetime(2024, 10, 31, 12, 0, tzinfo=timezone.utc) | ||
assert t[0].classpath == "my-class-path" | ||
assert t[0].kwargs == {} | ||
|
||
tis = session.query(TaskInstance).all() | ||
assert len(tis) == 1 | ||
|
||
assert tis[0].state == TaskInstanceState.DEFERRED | ||
assert tis[0].trigger_id == t[0].id | ||
assert tis[0].next_method == "execute_callback" | ||
assert tis[0].next_kwargs == {} | ||
assert tis[0].trigger_timeout is None | ||
|
||
|
||
class TestTIHealthEndpoint: | ||
def setup_method(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gets set in the db once the API processes the request, so it shouldn't/can't be part of the request
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The trigger_id right? I need to remove this comment then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah