-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Unify DbApiHook.run() method with the methods which override it #23971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f614f47
368420c
a184b34
f9e1461
6fd8745
dcc2deb
c3cdcd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,8 +17,9 @@ | |
| import warnings | ||
| from contextlib import closing | ||
| from datetime import datetime | ||
| from typing import Any, Optional | ||
| from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union | ||
|
|
||
| import sqlparse | ||
| from sqlalchemy import create_engine | ||
| from typing_extensions import Protocol | ||
|
|
||
|
|
@@ -27,6 +28,17 @@ | |
| from airflow.providers_manager import ProvidersManager | ||
| from airflow.utils.module_loading import import_string | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sqlalchemy.engine import CursorResult | ||
|
|
||
|
|
||
| def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]: | ||
| """Handler for DbApiHook.run() to return results""" | ||
| if cursor.returns_rows: | ||
| return cursor.fetchall() | ||
| else: | ||
| return None | ||
|
|
||
|
|
||
| def _backported_get_hook(connection, *, hook_params=None): | ||
| """Return hook based on conn_type | ||
|
|
@@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None): | |
| cur.execute(sql) | ||
| return cur.fetchone() | ||
|
|
||
| def run(self, sql, autocommit=False, parameters=None, handler=None): | ||
| @staticmethod | ||
| def strip_sql_string(sql: str) -> str: | ||
| return sql.strip().rstrip(';') | ||
|
|
||
| @staticmethod | ||
| def split_sql_string(sql: str) -> List[str]: | ||
| """ | ||
| Splits string into multiple SQL expressions | ||
|
|
||
| :param sql: SQL string potentially consisting of multiple expressions | ||
| :return: list of individual expressions | ||
| """ | ||
| splits = sqlparse.split(sqlparse.format(sql, strip_comments=True)) | ||
| statements = [s.rstrip(';') for s in splits if s.endswith(';')] | ||
| return statements | ||
|
|
||
| def run( | ||
| self, | ||
| sql: Union[str, Iterable[str]], | ||
| autocommit: bool = False, | ||
| parameters: Optional[Union[Iterable, Mapping]] = None, | ||
| handler: Optional[Callable] = None, | ||
| split_statements: bool = False, | ||
| return_last: bool = True, | ||
| ) -> Optional[Union[Any, List[Any]]]: | ||
| """ | ||
| Runs a command or a list of commands. Pass a list of sql | ||
| statements to the sql parameter to get them to execute | ||
|
|
@@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None): | |
| before executing the query. | ||
| :param parameters: The parameters to render the SQL query with. | ||
| :param handler: The result handler which is called with the result of each statement. | ||
| :return: query results if handler was provided. | ||
| :param split_statements: Whether to split a single SQL string into statements and run separately | ||
| :param return_last: Whether to return result for only last statement or for all after split | ||
| :return: return only result of the ALL SQL expressions if handler was provided. | ||
| """ | ||
| scalar = isinstance(sql, str) | ||
| if scalar: | ||
| sql = [sql] | ||
| scalar_return_last = isinstance(sql, str) and return_last | ||
| if isinstance(sql, str): | ||
potiuk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if split_statements: | ||
| sql = self.split_sql_string(sql) | ||
| else: | ||
| sql = [self.strip_sql_string(sql)] | ||
|
|
||
| if sql: | ||
| self.log.debug("Executing %d statements", len(sql)) | ||
| self.log.debug("Executing following statements against DB: %s", list(sql)) | ||
| else: | ||
| raise ValueError("List of SQL statements is empty") | ||
|
|
||
|
|
@@ -232,22 +273,21 @@ def run(self, sql, autocommit=False, parameters=None, handler=None): | |
| results = [] | ||
| for sql_statement in sql: | ||
| self._run_command(cur, sql_statement, parameters) | ||
|
|
||
| if handler is not None: | ||
| result = handler(cur) | ||
| results.append(result) | ||
|
|
||
| # If autocommit was set to False for db that supports autocommit, | ||
| # or if db does not supports autocommit, we do a manual commit. | ||
| # If autocommit was set to False or db does not support autocommit, we do a manual commit. | ||
| if not self.get_autocommit(conn): | ||
| conn.commit() | ||
|
|
||
| if handler is None: | ||
| return None | ||
|
|
||
| if scalar: | ||
| return results[0] | ||
|
|
||
| return results | ||
| elif scalar_return_last: | ||
| return results[-1] | ||
| else: | ||
|
||
| return results | ||
|
|
||
| def _run_command(self, cur, sql_statement, parameters): | ||
| """Runs a statement using an already open cursor.""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.