Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,26 +645,34 @@ def update_import_errors(

session.commit()

@provide_session
def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION):
def _convert_user_code_warnings(self, *, dagbag: DagBag, session: Session) -> Iterator[DagWarning]:
"""Convert collected DagUserCodeWarning instances to DagWarning in db."""
for warning in dagbag.user_code_warnings:
yield DagWarning(
source_loc=f"{warning.filename}:{warning.lineno}",
warning_type=DagWarningType.AUTH_IN_DATASET_URI,
message=str(warning.message),
)

def _validate_task_pools(self, *, dagbag: DagBag, session: Session) -> Iterator[DagWarning]:
"""Validate and raise exception if any task in a dag is using a non-existent pool."""
from airflow.models.pool import Pool

def check_pools(dag):
def check_pools(dag: DAG) -> Iterator[DagWarning]:
task_pools = {task.pool for task in dag.tasks}
nonexistent_pools = task_pools - pools
if nonexistent_pools:
return f"Dag '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}"
if not (nonexistent_pools := task_pools - pools):
return
yield DagWarning(
dag_id=dag.dag_id,
warning_type=DagWarningType.NONEXISTENT_POOL,
message=f"DAG '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}",
)

pools = {p.pool for p in Pool.get_pools(session)}
for dag in dagbag.dags.values():
message = check_pools(dag)
if message:
self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message))
yield from check_pools(dag)
for subdag in dag.subdags:
message = check_pools(subdag)
if message:
self.dag_warnings.add(DagWarning(subdag.dag_id, DagWarningType.NONEXISTENT_POOL, message))
yield from check_pools(subdag)

def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None:
"""
Expand All @@ -677,7 +685,8 @@ def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None:
:param session: session for ORM operations
:param dagbag: DagBag containing DAGs with configuration warnings
"""
self._validate_task_pools(dagbag=dagbag)
self.dag_warnings.update(self._convert_user_code_warnings(dagbag=dagbag, session=session))
self.dag_warnings.update(self._validate_task_pools(dagbag=dagbag, session=session))

stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dagbag.dags)).all())

Expand Down
4 changes: 3 additions & 1 deletion airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import attr

from airflow.exceptions import DagUserCodeWarning

if TYPE_CHECKING:
from urllib.parse import SplitResult

Expand Down Expand Up @@ -63,7 +65,7 @@ def _sanitize_uri(uri: str) -> str:
warnings.warn(
"A dataset URI should not contain auth info (e.g. username or "
"password). It has been automatically dropped.",
UserWarning,
DagUserCodeWarning,
stacklevel=3,
)
if parsed.query:
Expand Down
7 changes: 7 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,13 @@ class AirflowProviderDeprecationWarning(DeprecationWarning):
"Indicates the provider version that started raising this deprecation warning"


class DagUserCodeWarning(UserWarning):
"""Issued for user code in DAG file that should be warned about.

This will be collected by the DAG processor converted into a DagWarning entry.
"""


class DeserializingResultError(ValueError):
"""Raised when an error is encountered while a pickling library deserializes a pickle file."""

Expand Down
21 changes: 16 additions & 5 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
DagUserCodeWarning,
RemovedInAirflow3Warning,
)
from airflow.stats import Stats
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(
# the file's last modified timestamp when we last read it
self.file_last_changed: dict[str, datetime] = {}
self.import_errors: dict[str, str] = {}
self.user_code_warnings: list[warnings.WarningMessage] = []
self.has_logged = False
self.read_dags_from_db = read_dags_from_db
# Only used by read_dags_from_db=True
Expand Down Expand Up @@ -343,7 +345,10 @@ def parse(mod_name, filepath):
spec = importlib.util.spec_from_loader(mod_name, loader)
new_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
with warnings.catch_warnings(record=True) as code_warnings:
warnings.simplefilter("default", category=DagUserCodeWarning)
loader.exec_module(new_module)
self.user_code_warnings.extend(code_warnings)
return [new_module]
except Exception as e:
DagContext.autoregistered_dags.clear()
Expand Down Expand Up @@ -405,13 +410,16 @@ def _load_modules_from_zip(self, filepath, safe_mode):
del sys.modules[mod_name]

DagContext.current_autoregister_module_name = mod_name
fileloc = os.path.join(filepath, zip_info.filename)
try:
sys.path.insert(0, filepath)
current_module = importlib.import_module(mod_name)
with warnings.catch_warnings(record=True) as code_warnings:
warnings.simplefilter("default", category=DagUserCodeWarning)
current_module = importlib.import_module(mod_name)
self.user_code_warnings.extend(code_warnings)
mods.append(current_module)
except Exception as e:
DagContext.autoregistered_dags.clear()
fileloc = os.path.join(filepath, zip_info.filename)
self.log.exception("Failed to import: %s", fileloc)
if self.dagbag_import_error_tracebacks:
self.import_errors[fileloc] = traceback.format_exc(
Expand Down Expand Up @@ -439,8 +447,11 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
for dag, mod in top_level_dags:
dag.fileloc = mod.__file__
try:
dag.validate()
self.bag_dag(dag=dag, root_dag=dag)
with warnings.catch_warnings(record=True) as code_warnings:
warnings.simplefilter("default", category=DagUserCodeWarning)
dag.validate()
self.bag_dag(dag=dag, root_dag=dag)
self.user_code_warnings.extend(code_warnings)
except AirflowClusterPolicySkipDag:
pass
except Exception as e:
Expand Down
24 changes: 19 additions & 5 deletions airflow/models/dagwarning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from enum import Enum
from typing import TYPE_CHECKING

from sqlalchemy import Column, ForeignKeyConstraint, String, Text, delete, false, select
from sqlalchemy import Column, ForeignKeyConstraint, Index, String, Text, delete, false, select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base, StringID
Expand All @@ -42,13 +42,15 @@ class DagWarning(Base):
when parsing DAG and displayed on the Webserver in a flash message.
"""

dag_id = Column(StringID(), primary_key=True)
warning_type = Column(String(50), primary_key=True)
dag_id = Column(StringID())
source_loc = Column(String(2000))
warning_type = Column(String(50))
message = Column(Text, nullable=False)
timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow)

__tablename__ = "dag_warning"
__table_args__ = (
Index("idx_dag_loc_type_unique", dag_id, source_loc, warning_type, unique=True),
ForeignKeyConstraint(
("dag_id",),
["dag.dag_id"],
Expand All @@ -57,10 +59,21 @@ class DagWarning(Base):
),
)

def __init__(self, dag_id: str, error_type: str, message: str, **kwargs):
def __init__(
self,
*,
dag_id: str | None = None,
source_loc: str | None = None,
warning_type: str,
message: str,
**kwargs,
) -> None:
super().__init__(**kwargs)
if dag_id is None and source_loc is None:
raise TypeError("must provide either dag_id or source_loc")
self.dag_id = dag_id
self.warning_type = DagWarningType(error_type).value # make sure valid type
self.source_loc = source_loc
self.warning_type = DagWarningType(warning_type).value # make sure valid type
self.message = message

def __eq__(self, other) -> bool:
Expand Down Expand Up @@ -104,3 +117,4 @@ class DagWarningType(str, Enum):
"""

NONEXISTENT_POOL = "non-existent pool"
AUTH_IN_DATASET_URI = "auth in dataset URI"
4 changes: 2 additions & 2 deletions tests/api_connexion/endpoints/test_dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def setup_method(self):
session.add(DagModel(dag_id="dag1"))
session.add(DagModel(dag_id="dag2"))
session.add(DagModel(dag_id="dag3"))
session.add(DagWarning("dag1", "non-existent pool", "test message"))
session.add(DagWarning("dag2", "non-existent pool", "test message"))
session.add(DagWarning(dag_id="dag1", warning_type="non-existent pool", message="test message"))
session.add(DagWarning(dag_id="dag2", warning_type="non-existent pool", message="test message"))
session.commit()

def test_response_one(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_dagwarning.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_purge_inactive_dag_warnings(self, session):
session.commit()

dag_warnings = [
DagWarning("dag_1", "non-existent pool", "non-existent pool"),
DagWarning("dag_2", "non-existent pool", "non-existent pool"),
DagWarning(dag_id="dag_1", warning_type="non-existent pool", message="non-existent pool"),
DagWarning(dag_id="dag_2", warning_type="non-existent pool", message="non-existent pool"),
]
session.add_all(dag_warnings)
session.commit()
Expand Down