Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
import warnings
from contextlib import closing
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Type, Union

import sqlparse
from packaging.version import Version
from sqlalchemy import create_engine
from typing_extensions import Protocol

from airflow import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers_manager import ProvidersManager
from airflow.utils.module_loading import import_string
from airflow.version import version

if TYPE_CHECKING:
from sqlalchemy.engine import CursorResult
Expand Down Expand Up @@ -76,7 +78,21 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any:
"""


class DbApiHook(BaseHook):
# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below
# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise
# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to
# accept the new Hooks as not derived from the original DbApiHook
if Version(version) < Version('2.4'):
try:
from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
except ImportError:
# just in case we have a problem with circular import
BaseForDbApiHook: Type[BaseHook] = BaseHook # type: ignore[no-redef]
else:
BaseForDbApiHook: Type[BaseHook] = BaseHook # type: ignore[no-redef]


class DbApiHook(BaseForDbApiHook):
"""
Abstract base class for sql hooks.

Expand Down
20 changes: 20 additions & 0 deletions tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,19 @@

import pytest

from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook


class DbApiHookInProvider(DbApiHook):
conn_name_attr = 'test_conn_id'


class NonDbApiHook(BaseHook):
pass


class TestDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -390,3 +399,14 @@ def test_run_no_queries(self):
with pytest.raises(ValueError) as err:
self.db_hook.run(sql=[])
assert err.value.args[0] == "List of SQL statements is empty"

def test_instance_check_works_for_provider_derived_hook(self):
assert isinstance(DbApiHookInProvider(), DbApiHook)

def test_instance_check_works_for_non_db_api_hook(self):
assert not isinstance(NonDbApiHook(), DbApiHook)

def test_instance_check_works_for_legacy_db_api_hook(self):
from airflow.hooks.dbapi import DbApiHook as LegacyDbApiHook

assert isinstance(DbApiHookInProvider(), LegacyDbApiHook)