Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate public endpoint Patch Task Instance to FastAPI #44223

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def set_mapped_task_instance_note(
return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index)


@mark_fastapi_migration_done
@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
@provide_session
Expand Down
25 changes: 25 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
ConfigDict,
Field,
NonNegativeInt,
StringConstraints,
field_validator,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
Expand Down Expand Up @@ -150,3 +152,26 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):

task_instances: list[TaskInstanceHistoryResponse]
total_entries: int


class PatchTaskInstanceBody(BaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
new_state: str | None = None
note: Annotated[str, StringConstraints(max_length=1000)] | None = None

@field_validator("new_state", mode="before")
@classmethod
def validate_new_state(cls, ns: str | None) -> str:
"""Validate new_state."""
valid_states = [
vs.name.lower()
for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
]
if ns is None:
raise ValueError("'new_state' should not be empty")
ns = ns.lower()
if ns not in valid_states:
raise ValueError(f"'{ns}' is not one of {valid_states}")
return ns
189 changes: 189 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3390,6 +3390,91 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Task Instance
summary: Patch Task Instance
description: Update the state of a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: task_id
in: path
required: true
schema:
type: string
title: Task Id
- name: map_index
in: query
required: false
schema:
type: integer
default: -1
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PatchTaskInstanceBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped:
get:
tags:
Expand Down Expand Up @@ -3780,6 +3865,90 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Task Instance
summary: Patch Task Instance
description: Update the state of a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: task_id
in: path
required: true
schema:
type: string
title: Task Id
- name: map_index
in: path
required: true
schema:
type: integer
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PatchTaskInstanceBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances:
get:
tags:
Expand Down Expand Up @@ -6409,6 +6578,26 @@ components:
- unixname
title: JobResponse
description: Job serializer for responses.
PatchTaskInstanceBody:
properties:
dry_run:
type: boolean
title: Dry Run
default: true
new_state:
anyOf:
- type: string
- type: 'null'
title: New State
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
type: object
title: PatchTaskInstanceBody
description: Request body for Clear Task Instances endpoint.
PluginCollectionResponse:
properties:
plugins:
Expand Down
87 changes: 85 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

from typing import Annotated, Literal

from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import select

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.parameters import (
Expand All @@ -48,6 +49,7 @@
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.task_instances import (
PatchTaskInstanceBody,
TaskDependencyCollectionResponse,
TaskInstanceCollectionResponse,
TaskInstanceHistoryResponse,
Expand Down Expand Up @@ -472,3 +474,84 @@ def get_mapped_task_instance_try_details(
map_index=map_index,
session=session,
)


@task_instances_router.patch(
"/{task_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
@task_instances_router.patch(
"/{task_id}/{map_index}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
def patch_task_instance(
dag_id: str,
dag_run_id: str,
task_id: str,
request: Request,
body: PatchTaskInstanceBody,
session: Annotated[Session, Depends(get_session)],
map_index: int = -1,
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
update_mask: list[str] | None = Query(None),
) -> TaskInstanceResponse:
"""Update the state of a task instance."""
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")

if not dag.has_task(task_id):
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")

query = (
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.options(joinedload(TI.rendered_task_instance_fields))
)
if map_index == -1:
query = query.where(or_(TI.map_index == -1, TI.map_index is None))
else:
query = query.where(TI.map_index == map_index)

try:
ti = session.scalar(query)
except MultipleResultsFound:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
)

err_msg_404 = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
if ti is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

fields_to_update = body.model_fields_set
if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)

for field in fields_to_update:
if field == "new_state":
if not body.dry_run:
tis: list[TI] = dag.set_task_instance_state(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these to be included for updation as well? Let me know if required I'll add these too.

upstream=body.include_upstream,
downstream=body.include_downstream,
future=body.include_future,
past=body.include_past,
...

task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
state=body.new_state,
commit=True,
session=session,
)
if not ti:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
ti = tis[0] if isinstance(tis, list) else tis
elif field == "note":
if update_mask or body.note is not None:
# @TODO: replace None passed for user_id with actual user id when
# permissions and auth is in place.
if ti.task_instance_note is None:
ti.note = (body.note, None)
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = None
session.commit()

return TaskInstanceResponse.model_validate(ti, from_attributes=True)
6 changes: 6 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,12 @@ export type DagServicePatchDagMutationResult = Awaited<
export type PoolServicePatchPoolMutationResult = Awaited<
ReturnType<typeof PoolService.patchPool>
>;
export type TaskInstanceServicePatchTaskInstanceMutationResult = Awaited<
ReturnType<typeof TaskInstanceService.patchTaskInstance>
>;
export type TaskInstanceServicePatchTaskInstance1MutationResult = Awaited<
ReturnType<typeof TaskInstanceService.patchTaskInstance1>
>;
export type VariableServicePatchVariableMutationResult = Awaited<
ReturnType<typeof VariableService.patchVariable>
>;
Expand Down
Loading