Skip to content

Commit

Permalink
Make changes as per feedback from Pierre
Browse files Browse the repository at this point in the history
  • Loading branch information
omkar-foss committed Nov 21, 2024
1 parent c8d2c64 commit 8cd608c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 46 deletions.
16 changes: 1 addition & 15 deletions airflow/api_fastapi/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from datetime import datetime, timedelta
from datetime import timedelta
from enum import Enum
from typing import Annotated

Expand Down Expand Up @@ -66,20 +66,6 @@ class TimeDelta(BaseModel):
TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)]


def _validate_nonnaive_datetime_field(dt: datetime | None) -> datetime | None:
"""Validate and return the datetime field."""
if dt is None:
return None
if isinstance(dt, str):
dt = datetime.fromisoformat(dt)
if not dt.tzinfo:
raise ValueError("Invalid datetime format, Naive datetime is disallowed")
return dt


DatetimeWithNonNaiveValidation = Annotated[datetime, BeforeValidator(_validate_nonnaive_datetime_field)]


class Mimetype(str, Enum):
"""Mimetype for the `Content-Type` header."""

Expand Down
8 changes: 3 additions & 5 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Annotated, Any

from pydantic import (
AliasChoices,
AliasPath,
AwareDatetime,
BaseModel,
Expand All @@ -32,7 +31,6 @@
model_validator,
)

from airflow.api_fastapi.common.types import DatetimeWithNonNaiveValidation
from airflow.api_fastapi.core_api.datamodels.job import JobResponse
from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse
from airflow.utils.state import TaskInstanceState
Expand Down Expand Up @@ -160,8 +158,8 @@ class ClearTaskInstancesBody(BaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
start_date: DatetimeWithNonNaiveValidation | None = None
end_date: DatetimeWithNonNaiveValidation | None = None
start_date: AwareDatetime | None = None
end_date: AwareDatetime | None = None
only_failed: bool = True
only_running: bool = False
reset_dag_runs: bool = False
Expand Down Expand Up @@ -196,7 +194,7 @@ class TaskInstanceReferenceResponse(BaseModel):
"""Task Instance Reference serializer for responses."""

task_id: str
dag_run_id: str = Field(validation_alias=AliasChoices("run_id"))
dag_run_id: str = Field(validation_alias="run_id")
dag_id: str
logical_date: datetime

Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def get_mapped_task_instance(


@task_instances_router.get(
task_instances_prefix + "",
task_instances_prefix,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_task_instances(
Expand Down Expand Up @@ -541,15 +541,15 @@ def post_clear_task_instances(

task_instances = dag.clear(
dry_run=True,
task_ids=body.task_ids,
task_ids=task_ids,
dag_bag=request.app.state.dag_bag,
**body.model_dump(
include=[ # type: ignore[arg-type]
include={
"start_date",
"end_date",
"only_failed",
"only_running",
]
}
),
)

Expand Down
59 changes: 37 additions & 22 deletions tests/api_fastapi/core_api/routes/public/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@

import pendulum
import pytest
import sqlalchemy

from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.taskinstancehistory import TaskInstanceHistory
Expand Down Expand Up @@ -1828,21 +1830,36 @@ def test_should_respond_200(
assert response.status_code == 200
assert len(response.json()["task_instances"]) == expected_ti

def test_clear_taskinstance_is_called_with_queued_dr_state(self, test_client, session):
@mock.patch("airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances")
def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, test_client, session):
"""Test that if reset_dag_runs is True, then clear_task_instances is called with State.QUEUED"""
self.create_task_instances(session)
dag_id = "example_python_operator"
payload = {"reset_dag_runs": True, "dry_run": False}
self.dagbag.sync_to_db()
with mock.patch(
"airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances",
) as mp:
response = test_client.post(
f"/public/dags/{dag_id}/clearTaskInstances",
json=payload,
)
assert response.status_code == 200
mp.assert_called_once()
response = test_client.post(
f"/public/dags/{dag_id}/clearTaskInstances",
json=payload,
)
assert response.status_code == 200

# We check args individually instead of direct matching using
# assert_called_once_with(), because the session objects don't match
# and can't be skipped using mock.ANY.
mock_clearti.assert_called_once()
args, kwargs = mock_clearti.call_args
assert len(args) == 4
assert len(kwargs) == 0
# 1st argument
assert args[0] == []
# 2nd argument
assert args[1] is not None
assert isinstance(args[1], sqlalchemy.orm.session.Session)
# 3rd argument
assert args[2].dag_id, dag_id
assert isinstance(args[2], DAG)
# 4th argument
assert args[3] == DagRunState.QUEUED

def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, session):
"""Test that dagrun is running when invalid task_ids are passed to clearTaskInstances API."""
Expand Down Expand Up @@ -2227,11 +2244,10 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, test_client, session
{
"detail": [
{
"type": "value_error",
"type": "timezone_aware",
"loc": ["body", "end_date"],
"msg": "Value error, Invalid datetime format, Naive datetime is disallowed",
"msg": "Input should have timezone info",
"input": "2020-11-10T12:42:39.442973",
"ctx": {"error": {}},
}
]
},
Expand All @@ -2241,11 +2257,11 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, test_client, session
{
"detail": [
{
"type": "value_error",
"type": "datetime_from_date_parsing",
"loc": ["body", "end_date"],
"msg": "Value error, Invalid isoformat string: '2020-11-10T12:4po'",
"msg": "Input should be a valid datetime or date, unexpected extra characters at the end of the input",
"input": "2020-11-10T12:4po",
"ctx": {"error": {}},
"ctx": {"error": "unexpected extra characters at the end of the input"},
}
]
},
Expand All @@ -2255,11 +2271,10 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, test_client, session
{
"detail": [
{
"type": "value_error",
"type": "timezone_aware",
"loc": ["body", "start_date"],
"msg": "Value error, Invalid datetime format, Naive datetime is disallowed",
"msg": "Input should have timezone info",
"input": "2020-11-10T12:42:39.442973",
"ctx": {"error": {}},
}
]
},
Expand All @@ -2269,11 +2284,11 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, test_client, session
{
"detail": [
{
"type": "value_error",
"type": "datetime_from_date_parsing",
"loc": ["body", "start_date"],
"msg": "Value error, Invalid isoformat string: '2020-11-10T12:4po'",
"msg": "Input should be a valid datetime or date, unexpected extra characters at the end of the input",
"input": "2020-11-10T12:4po",
"ctx": {"error": {}},
"ctx": {"error": "unexpected extra characters at the end of the input"},
}
]
},
Expand Down

0 comments on commit 8cd608c

Please sign in to comment.