[AIP-49] OpenTelemetry Traces for Apache Airflow Part 2 #40802

Merged on Jul 20, 2024 (8 commits)
302 changes: 188 additions & 114 deletions airflow/dag_processing/manager.py

Large diffs are not rendered by default.

75 changes: 72 additions & 3 deletions airflow/executors/base_executor.py
@@ -32,6 +32,9 @@
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.traces import NO_TRACE_ID
from airflow.traces.tracer import Trace, gen_context, span
from airflow.traces.utils import gen_span_id_from_ti_key, gen_trace_id
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.task_context_logger import TaskContextLogger
from airflow.utils.state import TaskInstanceState
@@ -211,6 +214,7 @@ def sync(self) -> None:
Executors should override this to perform gather statuses.
"""

@span
def heartbeat(self) -> None:
"""Heartbeat sent to trigger new jobs."""
if not self.parallelism:
@@ -228,6 +232,17 @@ def heartbeat(self) -> None:
else:
self.log.debug("%s open slots", open_slots)

span = Trace.get_current_span()
if span.is_recording():
span.add_event(
name="executor",
attributes={
"executor.open_slots": open_slots,
"executor.queued_tasks": num_queued_tasks,
"executor.running_tasks": num_running_tasks,
},
)

Stats.gauge(
"executor.open_slots", value=open_slots, tags={"status": "open", "name": self.__class__.__name__}
)
@@ -260,12 +275,14 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTa
reverse=True,
)

@span
def trigger_tasks(self, open_slots: int) -> None:
"""
Initiate async execution of the queued tasks, up to the number of available slots.

:param open_slots: Number of open slots
"""
span = Trace.get_current_span()
sorted_queue = self.order_queued_tasks_by_priority()
task_tuples = []

@@ -302,15 +319,40 @@ def trigger_tasks(self, open_slots: int) -> None:
if key in self.attempts:
del self.attempts[key]
task_tuples.append((key, command, queue, ti.executor_config))
if span.is_recording():
span.add_event(
name="task to trigger",
attributes={"command": str(command), "conf": str(ti.executor_config)},
)

if task_tuples:
self._process_tasks(task_tuples)

@span
def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
for key, command, queue, executor_config in task_tuples:
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)
task_instance = self.queued_tasks[key][3] # TaskInstance in fourth element
trace_id = int(gen_trace_id(task_instance.dag_run, as_int=True))
span_id = int(gen_span_id_from_ti_key(key, as_int=True))
links = [{"trace_id": trace_id, "span_id": span_id}]

# assuming that the span_id will very likely be unique inside the trace
with Trace.start_span(
span_name=f"{key.dag_id}.{key.task_id}",
component="BaseExecutor",
span_id=span_id,
links=links,
) as span:
span.set_attribute("dag_id", key.dag_id)
span.set_attribute("run_id", key.run_id)
span.set_attribute("task_id", key.task_id)
span.set_attribute("try_number", key.try_number)
span.set_attribute("command", str(command))
Member:
What is the value of putting the command in the span? It simply duplicates (in a format that is of no use as it's a single string) the dag_id, run_id, task_id etc.

Contributor (Author):
> What is the value of putting the command in the span? It simply duplicates (in a format that is of no use as it's a single string) the dag_id, run_id, task_id etc.

I added the command as an attribute on the assumption that it may carry something beyond dag_id, run_id, task_id, etc. Since I don't have a deep understanding of everything the command can contain, it felt worth recording as part of the span. If there is no real value in instrumenting the command, we can remove that attribute in the next release.
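For context, a CommandType value is the CLI invocation Airflow builds for the task instance, so it does largely restate the key fields; any extra flags are the only additional information. A rough, illustrative example of what the attribute ends up holding (not taken from this PR; exact arguments vary by Airflow version and configuration):

# Illustrative only: a CommandType is a list of CLI arguments built for one task instance.
command = [
    "airflow", "tasks", "run",
    "example_dag", "example_task", "manual__2024-07-20T00:00:00+00:00",
    "--local", "--subdir", "DAGS_FOLDER/example_dag.py",
]
print(str(command))  # this single string is what the span attribute stores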

span.set_attribute("queue", str(queue))
span.set_attribute("executor_config", str(executor_config))
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)

def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
@@ -338,6 +380,20 @@ def fail(self, key: TaskInstanceKey, info=None) -> None:
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
trace_id = Trace.get_current_span().get_span_context().trace_id
if trace_id != NO_TRACE_ID:
span_id = int(gen_span_id_from_ti_key(key, as_int=True))
with Trace.start_span(
span_name="fail",
component="BaseExecutor",
parent_sc=gen_context(trace_id=trace_id, span_id=span_id),
) as span:
span.set_attribute("dag_id", key.dag_id)
span.set_attribute("run_id", key.run_id)
span.set_attribute("task_id", key.task_id)
span.set_attribute("try_number", key.try_number)
span.set_attribute("error", True)

self.change_state(key, TaskInstanceState.FAILED, info)

def success(self, key: TaskInstanceKey, info=None) -> None:
@@ -347,6 +403,19 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
trace_id = Trace.get_current_span().get_span_context().trace_id
if trace_id != NO_TRACE_ID:
span_id = int(gen_span_id_from_ti_key(key, as_int=True))
with Trace.start_span(
span_name="success",
component="BaseExecutor",
parent_sc=gen_context(trace_id=trace_id, span_id=span_id),
) as span:
span.set_attribute("dag_id", key.dag_id)
span.set_attribute("run_id", key.run_id)
span.set_attribute("task_id", key.task_id)
span.set_attribute("try_number", key.try_number - 1)

self.change_state(key, TaskInstanceState.SUCCESS, info)

def queued(self, key: TaskInstanceKey, info=None) -> None:
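The @span decorator applied throughout these files comes from airflow.traces.tracer, whose implementation is not part of this diff. As a rough mental model only (not Airflow's actual code), such a decorator can be sketched directly against the OpenTelemetry API:

import functools

from opentelemetry import trace


def span(func):
    """Run the wrapped callable inside a span named after the function."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tracer = trace.get_tracer(func.__module__)
        # start_as_current_span makes the span the "current" one, so code in the
        # body can fetch it again (as Trace.get_current_span() does above) and
        # attach attributes or events to it.
        with tracer.start_as_current_span(func.__qualname__):
            return func(*args, **kwargs)

    return wrapper

This also explains the is_recording() guards in the methods above: when no tracer is configured, get_current_span() returns a no-op span, and the check avoids building attribute payloads for spans that will never be exported.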
17 changes: 17 additions & 0 deletions airflow/executors/local_executor.py
Expand Up @@ -39,6 +39,7 @@
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.traces.tracer import Trace, span
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

@@ -77,6 +78,7 @@ def run(self):
setproctitle("airflow worker -- LocalExecutor")
return super().run()

@span
def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None:
"""
Execute command received and stores result state in queue.
@@ -98,6 +100,7 @@ def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None:
# Remove the command since the worker is done executing the task
setproctitle("airflow worker -- LocalExecutor")

@span
def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState:
try:
subprocess.check_call(command, close_fds=True)
@@ -106,6 +109,7 @@ def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState
self.log.error("Failed to execute task %s.", e)
return TaskInstanceState.FAILED

@span
def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
pid = os.fork()
if pid:
@@ -165,6 +169,7 @@ def __init__(
self.key: TaskInstanceKey = key
self.command: CommandType = command

@span
def do_work(self) -> None:
self.execute_work(key=self.key, command=self.command)

@@ -184,6 +189,7 @@ def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[Task
super().__init__(result_queue=result_queue)
self.task_queue = task_queue

@span
def do_work(self) -> None:
while True:
try:
@@ -244,6 +250,7 @@ def start(self) -> None:
self.executor.workers_used = 0
self.executor.workers_active = 0

@span
def execute_async(
self,
key: TaskInstanceKey,
@@ -262,6 +269,14 @@ def execute_async(
if TYPE_CHECKING:
assert self.executor.result_queue

span = Trace.get_current_span()
if span.is_recording():
span.set_attribute("dag_id", key.dag_id)
span.set_attribute("run_id", key.run_id)
span.set_attribute("task_id", key.task_id)
span.set_attribute("try_number", key.try_number - 1)
span.set_attribute("commands_to_run", str(command))

local_worker = LocalWorker(self.executor.result_queue, key=key, command=command)
self.executor.workers_used += 1
self.executor.workers_active += 1
@@ -311,6 +326,7 @@ def start(self) -> None:
for worker in self.executor.workers:
worker.start()

@span
def execute_async(
self,
key: TaskInstanceKey,
@@ -372,6 +388,7 @@ def start(self) -> None:

self.impl.start()

@span
def execute_async(
self,
key: TaskInstanceKey,
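The links=[{"trace_id": ..., "span_id": ...}] argument in _process_tasks and the parent_sc=gen_context(...) argument in fail()/success() above presumably get translated into OpenTelemetry SpanContext and Link objects built from the deterministically derived IDs, which is what lets spans created in separate processes be stitched into one trace without in-band context propagation. A sketch of that translation against the raw OpenTelemetry API, under that assumption (the helper below is illustrative, not the airflow.traces implementation):

from opentelemetry.trace import Link, NonRecordingSpan, SpanContext, TraceFlags

# Example IDs; in the PR these come from gen_trace_id(dag_run, as_int=True)
# and gen_span_id_from_ti_key(key, as_int=True).
trace_id = 0x1234567890ABCDEF1234567890ABCDEF
span_id = 0x1234567890ABCDEF


def remote_context(trace_id: int, span_id: int) -> SpanContext:
    # Because both IDs are derived deterministically from the dag_run and the
    # TaskInstance key, any process can rebuild the same context.
    return SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=True,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
    )


# Used as a link when starting a new span:
link = Link(remote_context(trace_id, span_id))

# Or used as the parent of a new span, via a non-recording placeholder:
parent = NonRecordingSpan(remote_context(trace_id, span_id))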
10 changes: 10 additions & 0 deletions airflow/executors/sequential_executor.py
@@ -29,6 +29,7 @@
from typing import TYPE_CHECKING, Any

from airflow.executors.base_executor import BaseExecutor
from airflow.traces.tracer import Trace, span

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -59,6 +60,7 @@ def __init__(self):
super().__init__()
self.commands_to_run = []

@span
def execute_async(
self,
key: TaskInstanceKey,
@@ -69,6 +71,14 @@ def execute_async(
self.validate_airflow_tasks_run_command(command)
self.commands_to_run.append((key, command))

span = Trace.get_current_span()
if span.is_recording():
span.set_attribute("dag_id", key.dag_id)
span.set_attribute("run_id", key.run_id)
span.set_attribute("task_id", key.task_id)
span.set_attribute("try_number", key.try_number - 1)
span.set_attribute("commands_to_run", str(self.commands_to_run))

def sync(self) -> None:
for key, command in self.commands_to_run:
self.log.info("Executing command: %s", command)
104 changes: 58 additions & 46 deletions airflow/jobs/job.py
@@ -34,6 +34,7 @@
from airflow.models.base import ID_LEN, Base
from airflow.serialization.pydantic.job import JobPydantic
from airflow.stats import Stats
from airflow.traces.tracer import Trace, span
from airflow.utils import timezone
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -199,52 +200,62 @@ def heartbeat(
:param session to use for saving the job
"""
previous_heartbeat = self.latest_heartbeat

try:
# This will cause it to load from the db
self._merge_from(Job._fetch_from_db(self, session))
previous_heartbeat = self.latest_heartbeat

if self.state == JobState.RESTARTING:
self.kill()

# Figure out how long to sleep for
sleep_for = 0
if self.latest_heartbeat:
seconds_remaining = (
self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds()
)
sleep_for = max(0, seconds_remaining)
sleep(sleep_for)

job = Job._update_heartbeat(job=self, session=session)
self._merge_from(job)
time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
if time_since_last_heartbeat > health_check_threshold_value:
self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat

heartbeat_callback(session)
self.log.debug("[heartbeat]")
self.heartbeat_failed = False
except OperationalError:
Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1)
if not self.heartbeat_failed:
self.log.exception("%s heartbeat failed with error", self.__class__.__name__)
self.heartbeat_failed = True
if self.is_alive():
self.log.error(
"%s heartbeat failed with error. Scheduler may go into unhealthy state",
self.__class__.__name__,
)
else:
self.log.error(
"%s heartbeat failed with error. Scheduler is in unhealthy state", self.__class__.__name__
)
# We didn't manage to heartbeat, so make sure that the timestamp isn't updated
self.latest_heartbeat = previous_heartbeat
with Trace.start_span(span_name="heartbeat", component="Job") as span:
try:
span.set_attribute("heartbeat", str(self.latest_heartbeat))
# This will cause it to load from the db
self._merge_from(Job._fetch_from_db(self, session))
previous_heartbeat = self.latest_heartbeat

if self.state == JobState.RESTARTING:
self.kill()

# Figure out how long to sleep for
sleep_for = 0
if self.latest_heartbeat:
seconds_remaining = (
self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds()
)
sleep_for = max(0, seconds_remaining)
if span.is_recording():
span.add_event(name="sleep", attributes={"sleep_for": sleep_for})
sleep(sleep_for)

job = Job._update_heartbeat(job=self, session=session)
self._merge_from(job)
time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
if time_since_last_heartbeat > health_check_threshold_value:
self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat

heartbeat_callback(session)
self.log.debug("[heartbeat]")
self.heartbeat_failed = False
except OperationalError:
Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1)
if not self.heartbeat_failed:
self.log.exception("%s heartbeat failed with error", self.__class__.__name__)
self.heartbeat_failed = True
msg = f"{self.__class__.__name__} heartbeat got an exception"
if span.is_recording():
span.add_event(name="error", attributes={"message": msg})
if self.is_alive():
self.log.error(
"%s heartbeat failed with error. Scheduler may go into unhealthy state",
self.__class__.__name__,
)
msg = f"{self.__class__.__name__} heartbeat failed with error. Scheduler may go into unhealthy state"
if span.is_recording():
span.add_event(name="error", attributes={"message": msg})
else:
msg = f"{self.__class__.__name__} heartbeat failed with error. Scheduler is in unhealthy state"
self.log.error(msg)
if span.is_recording():
span.add_event(name="error", attributes={"message": msg})
# We didn't manage to heartbeat, so make sure that the timestamp isn't updated
self.latest_heartbeat = previous_heartbeat

@provide_session
def prepare_for_execution(self, session: Session = NEW_SESSION):
Expand Down Expand Up @@ -448,6 +459,7 @@ def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | N
return ret


@span
def perform_heartbeat(
job: Job, heartbeat_callback: Callable[[Session], None], only_if_necessary: bool
) -> None:
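To see what the heartbeat span and its "sleep"/"error" events look like once exported, a standalone snippet with the OpenTelemetry SDK's console exporter can help. This bypasses Airflow's own Trace wiring and configuration entirely; it only illustrates the underlying OpenTelemetry calls that the code above maps onto:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Export finished spans straight to stdout for inspection.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

tracer = trace.get_tracer("airflow.demo")  # tracer name is illustrative
with tracer.start_as_current_span("heartbeat") as span:
    span.set_attribute("heartbeat", "2024-07-20T00:00:00+00:00")
    span.add_event(name="sleep", attributes={"sleep_for": 4.2})
# On exit, the span is printed as JSON with its attributes and events.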