Skip to content

Commit

Permalink
feat(api): add disconnect method (#8341)
Browse files Browse the repository at this point in the history
## Description of changes

This adds a `disconnect` method to all backends. Previously we didn't do
this since the actual connection was often wrapped in SQLAlchemy and our
ability to terminate the connection was unclear.

The DB-API spec states that there should be a `close` method, and that
any subsequent operations on a given `Connection` should raise an error
after `close` is called.

This _mostly_ works.

Trino, Clickhouse, Impala, and BigQuery do not conform to the DB-API in
this way.  They have the `close` method but don't raise when you make a
subsequent call.

Spark is not a DB-API, but the `spark.sql.session.stop` method does the
right thing.

For the in-memory backends and Flink, this is a no-op as there is either 
nothing to disconnect or no exposed method to disconnect.



## Issues closed

Resolves #5940
  • Loading branch information
gforsyth committed Feb 14, 2024
1 parent 279357a commit 32665af
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,10 @@ def connect(self, *args, **kwargs) -> BaseBackend:
new_backend.reconnect()
return new_backend

@abc.abstractmethod
def disconnect(self) -> None:
"""Close the connection to the backend."""

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
"""Manipulate keyword arguments to `.connect` method."""
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,8 @@ def truncate_table(
).sql(self.dialect)
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass

def disconnect(self):
# This is part of the Python DB-API specification so should work for
# _most_ sqlglot backends
self.con.close()
3 changes: 3 additions & 0 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ def do_connect(

self.partition_column = partition_column

def disconnect(self) -> None:
self.client.close()

def _parse_project_and_dataset(self, dataset) -> tuple[str, str]:
if not dataset and not self.dataset:
raise ValueError("Unable to determine BigQuery dataset.")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def raw_sql(
self._log(query)
return self.con.query(query, external_data=external_data, **kwargs)

def close(self) -> None:
def disconnect(self) -> None:
"""Close ClickHouse connection."""
self.con.close()

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def do_connect(
)
super().do_connect(dictionary)

def disconnect(self) -> None:
pass

@property
def version(self):
return dask.__version__
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def do_connect(
for name, path in config.items():
self.register(path, table_name=name)

def disconnect(self) -> None:
pass

@contextlib.contextmanager
def _safe_raw_sql(self, sql: sge.Statement) -> Any:
yield self.raw_sql(sql).collect()
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def do_connect(self, table_env: TableEnvironment) -> None:
"""
self._table_env = table_env

def disconnect(self) -> None:
pass

def raw_sql(self, query: str) -> TableResult:
return self._table_env.execute_sql(query)

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def do_connect(
self.dictionary = dictionary or {}
self.schemas: MutableMapping[str, sch.Schema] = {}

def disconnect(self) -> None:
pass

def from_dataframe(
self,
df: pd.DataFrame,
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def do_connect(
for name, table in (tables or {}).items():
self._add_table(name, table)

def disconnect(self) -> None:
pass

@property
def version(self) -> str:
return pl.__version__
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def do_connect(self, session: SparkSession) -> None:
self._session.conf.set("spark.sql.session.timeZone", "UTC")
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

def disconnect(self) -> None:
self._session.stop()

def _metadata(self, query: str):
cursor = self.raw_sql(query)
struct_dtype = PySparkType.to_ibis(cursor.query.schema)
Expand Down
25 changes: 25 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,3 +1460,28 @@ def test_list_databases_schemas(con_create_database_schema):
con_create_database_schema.drop_schema(schema, database=database)
finally:
con_create_database_schema.drop_database(database)


@pytest.mark.notyet(
["pandas", "dask", "polars", "datafusion"],
reason="this is a no-op for in-memory backends",
)
@pytest.mark.notyet(
["trino", "clickhouse", "impala", "bigquery", "flink"],
reason="Backend client does not conform to DB-API, subsequent op does not raise",
)
@pytest.mark.skip()
def test_close_connection(con):
if con.name == "pyspark":
# It would be great if there were a simple way to say "give me a new
# spark context" but I haven't found it.
pytest.skip("Closing spark context breaks subsequent tests")
new_con = getattr(ibis, con.name).connect(*con._con_args, **con._con_kwargs)

# Run any command that hits the backend
_ = new_con.list_tables()
new_con.disconnect()

# DB-API states that subsequent execution attempt should raise
with pytest.raises(Exception): # noqa:B017
new_con.list_tables()
3 changes: 3 additions & 0 deletions ibis/tests/expr/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __init__(self):
def do_connect(self):
pass

def disconnect(self):
pass

def table(self, name, **kwargs):
schema = self.get_schema(name)
node = ops.DatabaseTable(source=self, name=name, schema=schema)
Expand Down

0 comments on commit 32665af

Please sign in to comment.