From 36e48cc7c4107123d0bca9f0f663e17bc02d91b8 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 12 Sep 2022 16:15:43 -0500 Subject: [PATCH] feat(duckdb): register tables from pandas/pyarrow objects --- ibis/backends/duckdb/__init__.py | 32 ++++++++++++++------- ibis/backends/duckdb/tests/test_register.py | 24 ++++++++++++++++ 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index b34db683ba23..b2230b02883f 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import itertools import os import warnings from pathlib import Path @@ -27,6 +28,8 @@ _generate_view_code = RegexDispatcher("_register") _dialect = sa.dialects.postgresql.dialect() +_gen_table_names = (f"registered_table{i:d}" for i in itertools.count()) + def _name_from_path(path: Path) -> str: base, *_ = path.name.partition(os.extsep) @@ -146,28 +149,37 @@ def do_connect( def register( self, - path: str | Path, + source: str | Path | Any, table_name: str | None = None, ) -> ir.Table: - """Register an external file as a table in the current connection - database + """Register a data source as a table in the current database. Parameters ---------- - path - Name of the parquet or CSV file + source + The data source. May be a path to a file or directory of + parquet/csv files, a pandas dataframe, or a pyarrow table or + dataset. table_name - Name for the created table. Defaults to filename if not given. - Any dashes in a user-provided or generated name will be - replaced with underscores. + An optional name to use for the created table. This defaults to the + filename if `source` is a path (with hyphens replaced with underscores), or + a sequentially generated name otherwise. 
Returns ------- ir.Table The just-registered table """ - view, table_name = _generate_view_code(path, table_name=table_name) - self.con.execute(view) + if isinstance(source, (str, Path)): + sql, table_name = _generate_view_code( + source, table_name=table_name + ) + self.con.execute(sql) + else: + if table_name is None: + table_name = next(_gen_table_names) + self.con.execute("register", (table_name, source)) + return self.table(table_name) def fetch_from_cursor( diff --git a/ibis/backends/duckdb/tests/test_register.py b/ibis/backends/duckdb/tests/test_register.py index c46d37889c84..94b7e322d739 100644 --- a/ibis/backends/duckdb/tests/test_register.py +++ b/ibis/backends/duckdb/tests/test_register.py @@ -107,3 +107,27 @@ def test_register_parquet( table = con.table(out_table_name) assert table.count().execute() + + +def test_register_pandas(): + pd = pytest.importorskip("pandas") + df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + + con = ibis.duckdb.connect() + + t = con.register(df) + assert t.x.sum().execute() == 6 + + t = con.register(df, "my_table") + assert t.op().name == "my_table" + assert t.x.sum().execute() == 6 + + +def test_register_pyarrow_tables(): + pa = pytest.importorskip("pyarrow") + pa_t = pa.Table.from_pydict({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + + con = ibis.duckdb.connect() + + t = con.register(pa_t) + assert t.x.sum().execute() == 6