diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 76d79808500ea..bd2259802d8bb 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -17,9 +17,10 @@ import warnings from contextlib import closing from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Type, Union import sqlparse +from packaging.version import Version from sqlalchemy import create_engine from typing_extensions import Protocol @@ -27,6 +28,7 @@ from airflow.hooks.base import BaseHook from airflow.providers_manager import ProvidersManager from airflow.utils.module_loading import import_string +from airflow.version import version if TYPE_CHECKING: from sqlalchemy.engine import CursorResult @@ -76,7 +78,21 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any: """ -class DbApiHook(BaseHook): +# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below +# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise +# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to +# accept the new Hooks as not derived from the original DbApiHook +if Version(version) < Version('2.4'): + try: + from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook + except ImportError: + # just in case we have a problem with circular import + BaseForDbApiHook: Type[BaseHook] = BaseHook # type: ignore[no-redef] +else: + BaseForDbApiHook: Type[BaseHook] = BaseHook # type: ignore[no-redef] + + +class DbApiHook(BaseForDbApiHook): """ Abstract base class for sql hooks. diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index a44fa57e075a7..e957e264f3b8a 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -23,10 +23,19 @@ import pytest +from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook +class DbApiHookInProvider(DbApiHook): + conn_name_attr = 'test_conn_id' + + +class NonDbApiHook(BaseHook): + pass + + class TestDbApiHook(unittest.TestCase): def setUp(self): super().setUp() @@ -390,3 +399,14 @@ def test_run_no_queries(self): with pytest.raises(ValueError) as err: self.db_hook.run(sql=[]) assert err.value.args[0] == "List of SQL statements is empty" + + def test_instance_check_works_for_provider_derived_hook(self): + assert isinstance(DbApiHookInProvider(), DbApiHook) + + def test_instance_check_works_for_non_db_api_hook(self): + assert not isinstance(NonDbApiHook(), DbApiHook) + + def test_instance_check_works_for_legacy_db_api_hook(self): + from airflow.hooks.dbapi import DbApiHook as LegacyDbApiHook + + assert isinstance(DbApiHookInProvider(), LegacyDbApiHook)