From 57c79281663d8831babae696c6aade98c69a98f5 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 27 Feb 2024 21:45:34 +0800 Subject: [PATCH] Implement a way to forward warnings to db --- airflow/dag_processing/processor.py | 35 ++++++++++++------- airflow/datasets/__init__.py | 4 ++- airflow/exceptions.py | 7 ++++ airflow/models/dagbag.py | 21 ++++++++--- airflow/models/dagwarning.py | 24 ++++++++++--- .../endpoints/test_dag_warning_endpoint.py | 4 +-- tests/models/test_dagwarning.py | 4 +-- 7 files changed, 71 insertions(+), 28 deletions(-) diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index ce4c51552ba43..f7aadb8330543 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -645,26 +645,34 @@ def update_import_errors( session.commit() - @provide_session - def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION): + def _convert_user_code_warnings(self, *, dagbag: DagBag, session: Session) -> Iterator[DagWarning]: + """Convert collected DagUserCodeWarning instances to DagWarning in db.""" + for warning in dagbag.user_code_warnings: + yield DagWarning( + source_loc=f"{warning.filename}:{warning.lineno}", + warning_type=DagWarningType.AUTH_IN_DATASET_URI, + message=str(warning.message), + ) + + def _validate_task_pools(self, *, dagbag: DagBag, session: Session) -> Iterator[DagWarning]: """Validate and raise exception if any task in a dag is using a non-existent pool.""" from airflow.models.pool import Pool - def check_pools(dag): + def check_pools(dag: DAG) -> Iterator[DagWarning]: task_pools = {task.pool for task in dag.tasks} - nonexistent_pools = task_pools - pools - if nonexistent_pools: - return f"Dag '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}" + if not (nonexistent_pools := task_pools - pools): + return + yield DagWarning( + dag_id=dag.dag_id, + warning_type=DagWarningType.NONEXISTENT_POOL, + message=f"DAG '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}", + ) pools = {p.pool for p in Pool.get_pools(session)} for dag in dagbag.dags.values(): - message = check_pools(dag) - if message: - self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) + yield from check_pools(dag) for subdag in dag.subdags: - message = check_pools(subdag) - if message: - self.dag_warnings.add(DagWarning(subdag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) + yield from check_pools(subdag) def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None: """ @@ -677,7 +685,8 @@ def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None: :param session: session for ORM operations :param dagbag: DagBag containing DAGs with configuration warnings """ - self._validate_task_pools(dagbag=dagbag) + self.dag_warnings.update(self._convert_user_code_warnings(dagbag=dagbag, session=session)) + self.dag_warnings.update(self._validate_task_pools(dagbag=dagbag, session=session)) stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dagbag.dags)).all()) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 4f8d587727f14..71dd1e31ae393 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -24,6 +24,8 @@ import attr +from airflow.exceptions import DagUserCodeWarning + if TYPE_CHECKING: from urllib.parse import SplitResult @@ -63,7 +65,7 @@ def _sanitize_uri(uri: str) -> str: warnings.warn( "A dataset URI should not contain auth info (e.g. username or " "password). It has been automatically dropped.", - UserWarning, + DagUserCodeWarning, stacklevel=3, ) if parsed.query: diff --git a/airflow/exceptions.py b/airflow/exceptions.py index f8ab14db56d32..2c9b4f2089987 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -424,6 +424,13 @@ class AirflowProviderDeprecationWarning(DeprecationWarning): "Indicates the provider version that started raising this deprecation warning" +class DagUserCodeWarning(UserWarning): + """Issued for user code in DAG file that should be warned about. + + This will be collected by the DAG processor converted into a DagWarning entry. + """ + + class DeserializingResultError(ValueError): """Raised when an error is encountered while a pickling library deserializes a pickle file.""" diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ce9bf5587be80..a649248cf8377 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -41,6 +41,7 @@ AirflowClusterPolicyViolation, AirflowDagCycleException, AirflowDagDuplicatedIdException, + DagUserCodeWarning, RemovedInAirflow3Warning, ) from airflow.stats import Stats @@ -134,6 +135,7 @@ def __init__( # the file's last modified timestamp when we last read it self.file_last_changed: dict[str, datetime] = {} self.import_errors: dict[str, str] = {} + self.user_code_warnings: list[warnings.WarningMessage] = [] self.has_logged = False self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True @@ -343,7 +345,10 @@ def parse(mod_name, filepath): spec = importlib.util.spec_from_loader(mod_name, loader) new_module = importlib.util.module_from_spec(spec) sys.modules[spec.name] = new_module - loader.exec_module(new_module) + with warnings.catch_warnings(record=True) as code_warnings: + warnings.simplefilter("default", category=DagUserCodeWarning) + loader.exec_module(new_module) + self.user_code_warnings.extend(code_warnings) return [new_module] except Exception as e: DagContext.autoregistered_dags.clear() @@ -405,13 +410,16 @@ def _load_modules_from_zip(self, filepath, safe_mode): del sys.modules[mod_name] DagContext.current_autoregister_module_name = mod_name + fileloc = os.path.join(filepath, zip_info.filename) try: sys.path.insert(0, filepath) - current_module = importlib.import_module(mod_name) + with warnings.catch_warnings(record=True) as code_warnings: + warnings.simplefilter("default", category=DagUserCodeWarning) + current_module = importlib.import_module(mod_name) + self.user_code_warnings.extend(code_warnings) mods.append(current_module) except Exception as e: DagContext.autoregistered_dags.clear() - fileloc = os.path.join(filepath, zip_info.filename) self.log.exception("Failed to import: %s", fileloc) if self.dagbag_import_error_tracebacks: self.import_errors[fileloc] = traceback.format_exc( @@ -439,8 +447,11 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk): for dag, mod in top_level_dags: dag.fileloc = mod.__file__ try: - dag.validate() - self.bag_dag(dag=dag, root_dag=dag) + with warnings.catch_warnings(record=True) as code_warnings: + warnings.simplefilter("default", category=DagUserCodeWarning) + dag.validate() + self.bag_dag(dag=dag, root_dag=dag) + self.user_code_warnings.extend(code_warnings) except AirflowClusterPolicySkipDag: pass except Exception as e: diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py index 789fe0172784b..c88f9da5a2043 100644 --- a/airflow/models/dagwarning.py +++ b/airflow/models/dagwarning.py @@ -20,7 +20,7 @@ from enum import Enum from typing import TYPE_CHECKING -from sqlalchemy import Column, ForeignKeyConstraint, String, Text, delete, false, select +from sqlalchemy import Column, ForeignKeyConstraint, Index, String, Text, delete, false, select from airflow.api_internal.internal_api_call import internal_api_call from airflow.models.base import Base, StringID @@ -42,13 +42,15 @@ class DagWarning(Base): when parsing DAG and displayed on the Webserver in a flash message. """ - dag_id = Column(StringID(), primary_key=True) - warning_type = Column(String(50), primary_key=True) + dag_id = Column(StringID()) + source_loc = Column(String(2000)) + warning_type = Column(String(50)) message = Column(Text, nullable=False) timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow) __tablename__ = "dag_warning" __table_args__ = ( + Index("idx_dag_loc_type_unique", dag_id, source_loc, warning_type, unique=True), ForeignKeyConstraint( ("dag_id",), ["dag.dag_id"], @@ -57,10 +59,21 @@ class DagWarning(Base): ), ) - def __init__(self, dag_id: str, error_type: str, message: str, **kwargs): + def __init__( + self, + *, + dag_id: str | None = None, + source_loc: str | None = None, + warning_type: str, + message: str, + **kwargs, + ) -> None: super().__init__(**kwargs) + if dag_id is None and source_loc is None: + raise TypeError("must provide either dag_id or source_loc") self.dag_id = dag_id - self.warning_type = DagWarningType(error_type).value # make sure valid type + self.source_loc = source_loc + self.warning_type = DagWarningType(warning_type).value # make sure valid type self.message = message def __eq__(self, other) -> bool: @@ -104,3 +117,4 @@ class DagWarningType(str, Enum): """ NONEXISTENT_POOL = "non-existent pool" + AUTH_IN_DATASET_URI = "auth in dataset URI" diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 9310956d24f63..cc72a55986ae3 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -88,8 +88,8 @@ def setup_method(self): session.add(DagModel(dag_id="dag1")) session.add(DagModel(dag_id="dag2")) session.add(DagModel(dag_id="dag3")) - session.add(DagWarning("dag1", "non-existent pool", "test message")) - session.add(DagWarning("dag2", "non-existent pool", "test message")) + session.add(DagWarning(dag_id="dag1", warning_type="non-existent pool", message="test message")) + session.add(DagWarning(dag_id="dag2", warning_type="non-existent pool", message="test message")) session.commit() def test_response_one(self): diff --git a/tests/models/test_dagwarning.py b/tests/models/test_dagwarning.py index 58d6e5c752c64..1cc215cca7901 100644 --- a/tests/models/test_dagwarning.py +++ b/tests/models/test_dagwarning.py @@ -44,8 +44,8 @@ def test_purge_inactive_dag_warnings(self, session): session.commit() dag_warnings = [ - DagWarning("dag_1", "non-existent pool", "non-existent pool"), - DagWarning("dag_2", "non-existent pool", "non-existent pool"), + DagWarning(dag_id="dag_1", warning_type="non-existent pool", message="non-existent pool"), + DagWarning(dag_id="dag_2", warning_type="non-existent pool", message="non-existent pool"), ] session.add_all(dag_warnings) session.commit()