Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ class AirflowRescheduleException(AirflowException):
:param reschedule_date: The date when the task should be rescheduled
"""

def __init__(self, reschedule_date):
def __init__(self, reschedule_date, poke_number):
super().__init__()
self.reschedule_date = reschedule_date
self.poke_number = poke_number


class InvalidStatsNameException(AirflowException):
Expand Down
29 changes: 19 additions & 10 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,9 @@ def try_number(self):
If the TaskInstance is currently running, this will match the column in the
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
if self.state == State.RUNNING:
# This is designed so that task logs end up in the right file. Also, it should return
# the right try_number when the task state is not one of the finished states
if self.state in State.unfinished:
return self._try_number
return self._try_number + 1

Expand Down Expand Up @@ -1295,6 +1296,9 @@ def check_and_change_state_before_execution(
# If the task continues after being deferred (next_method is set), use the original start_date
self.start_date = self.start_date if self.next_method else timezone.utcnow()
if self.state == State.UP_FOR_RESCHEDULE:
# FIXME: unfortunately the state is always queued in this method,
# we should find a way to get the state before queuing the TI state,
# otherwise this block will be never reached as it is the case now.
task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
if task_reschedule:
self.start_date = task_reschedule.start_date
Expand Down Expand Up @@ -1757,7 +1761,11 @@ def dry_run(self) -> None:

@provide_session
def _handle_reschedule(
self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
self,
actual_start_date,
reschedule_exception: AirflowRescheduleException,
test_mode=False,
session=NEW_SESSION,
):
# Don't record reschedule request in test mode
if test_mode:
Expand All @@ -1783,13 +1791,14 @@ def _handle_reschedule(
# Log reschedule request
session.add(
TaskReschedule(
self.task,
self.run_id,
self._try_number,
actual_start_date,
self.end_date,
reschedule_exception.reschedule_date,
self.map_index,
task=self.task,
run_id=self.run_id,
try_number=self._try_number,
poke_number=reschedule_exception.poke_number,
start_date=actual_start_date,
end_date=self.end_date,
reschedule_date=reschedule_exception.reschedule_date,
map_index=self.map_index,
)
)

Expand Down
12 changes: 9 additions & 3 deletions airflow/models/taskreschedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlalchemy.orm import Query, Session, relationship

from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

Expand All @@ -45,6 +46,7 @@ class TaskReschedule(Base):
run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
map_index = Column(Integer, nullable=False, server_default=text("-1"))
try_number = Column(Integer, nullable=False)
poke_number = Column(Integer, nullable=False)
start_date = Column(UtcDateTime, nullable=False)
end_date = Column(UtcDateTime, nullable=False)
duration = Column(Integer, nullable=False)
Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
task: Operator,
run_id: str,
try_number: int,
poke_number: int,
start_date: datetime.datetime,
end_date: datetime.datetime,
reschedule_date: datetime.datetime,
Expand All @@ -89,6 +92,7 @@ def __init__(
self.run_id = run_id
self.map_index = map_index
self.try_number = try_number
self.poke_number = poke_number
self.start_date = start_date
self.end_date = end_date
self.reschedule_date = reschedule_date
Expand All @@ -97,7 +101,7 @@ def __init__(
@staticmethod
@provide_session
def query_for_task_instance(
task_instance: TaskInstance,
task_instance: TaskInstance | TaskInstancePydantic,
descending: bool = False,
session: Session = NEW_SESSION,
try_number: int | None = None,
Expand All @@ -112,7 +116,9 @@ def query_for_task_instance(
looks for the same try_number of the given task_instance.
"""
if try_number is None:
try_number = task_instance.try_number
# Since the TI state (up_for_reschedule) is in one of the not finished states,
# we should add 1 to its try_number to get the TaskReschedule try_number
try_number = task_instance.try_number + 1

TR = TaskReschedule
qry = session.query(TR).filter(
Expand All @@ -130,7 +136,7 @@ def query_for_task_instance(
@staticmethod
@provide_session
def find_for_task_instance(
task_instance: TaskInstance,
task_instance: TaskInstance | TaskInstancePydantic,
session: Session = NEW_SESSION,
try_number: int | None = None,
) -> list[TaskReschedule]:
Expand Down
31 changes: 19 additions & 12 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,25 @@ def poke(self, context: Context) -> bool | PokeReturnValue:

def execute(self, context: Context) -> Any:
started_at: datetime.datetime | float
poke_number: int

if self.reschedule:

# If reschedule, use the start date of the first try (first try can be either the very
# first execution of the task, or the first execution after the task was cleared.)
first_try_number = context["ti"].max_tries - self.retries + 1
task_reschedules = TaskReschedule.find_for_task_instance(
context["ti"], try_number=first_try_number
)
# we should add 1 to match the TaskReschedule try_number
task_reschedules = None
if context["ti"].max_tries is not None and self.retries is not None:
first_try_number = context["ti"].max_tries - self.retries + 1
task_reschedules = TaskReschedule.find_for_task_instance(
context["ti"], try_number=first_try_number
)
if not task_reschedules:
start_date = timezone.utcnow()
poke_number = 1
else:
start_date = task_reschedules[0].start_date
poke_number = task_reschedules[-1].poke_number + 1
started_at = start_date

def run_duration() -> float:
Expand All @@ -199,16 +205,17 @@ def run_duration() -> float:

else:
started_at = start_monotonic = time.monotonic()
poke_number = 1

def run_duration() -> float:
return time.monotonic() - start_monotonic

try_number = 1
log_dag_id = self.dag.dag_id if self.has_dag() else ""

xcom_value = None
while True:
try:
self.log.info("Poke number {}", poke_number)
poke_return = self.poke(context)
except (
AirflowSensorTimeout,
Expand Down Expand Up @@ -241,34 +248,34 @@ def run_duration() -> float:
else:
raise AirflowSensorTimeout(message)
if self.reschedule:
next_poke_interval = self._get_next_poke_interval(started_at, run_duration, try_number)
next_poke_interval = self._get_next_poke_interval(started_at, run_duration, poke_number)
reschedule_date = timezone.utcnow() + timedelta(seconds=next_poke_interval)
if _is_metadatabase_mysql() and reschedule_date > _MYSQL_TIMESTAMP_MAX:
raise AirflowSensorTimeout(
f"Cannot reschedule DAG {log_dag_id} to {reschedule_date.isoformat()} "
f"since it is over MySQL's TIMESTAMP storage limit."
)
raise AirflowRescheduleException(reschedule_date)
raise AirflowRescheduleException(reschedule_date=reschedule_date, poke_number=poke_number)
else:
time.sleep(self._get_next_poke_interval(started_at, run_duration, try_number))
try_number += 1
time.sleep(self._get_next_poke_interval(started_at, run_duration, poke_number))
poke_number += 1
self.log.info("Success criteria met. Exiting.")
return xcom_value

def _get_next_poke_interval(
self,
started_at: datetime.datetime | float,
run_duration: Callable[[], float],
try_number: int,
poke_number: int,
) -> float:
"""Using the similar logic which is used for exponential backoff retry delay for operators."""
if not self.exponential_backoff:
return self.poke_interval

min_backoff = int(self.poke_interval * (2 ** (try_number - 2)))
min_backoff = int(self.poke_interval * (2 ** (poke_number - 2)))

run_hash = int(
hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{poke_number}".encode()).hexdigest(),
16,
)
modded_hash = min_backoff + run_hash % min_backoff
Expand Down