Skip to content

Commit

Permalink
feat(datafusion): datafusion enhancements (#9544)
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored Jul 15, 2024
1 parent cba7367 commit f11ca43
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
11 changes: 6 additions & 5 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
src = sge.Create(
this=table,
kind="VIEW",
expression=sg.parse_one(query, read="datafusion"),
expression=sg.parse_one(query, read=self.dialect),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
)

Expand Down Expand Up @@ -537,13 +537,13 @@ def make_gen():
# convert the renamed + casted columns into a record batch
pa.RecordBatch.from_struct_array(
# rename columns to match schema because datafusion lowercases things
pa.RecordBatch.from_arrays(batch.columns, names=names)
pa.RecordBatch.from_arrays(batch.to_pyarrow().columns, names=names)
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
for batch in frame.execute_stream()
)

return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen())
Expand Down Expand Up @@ -628,7 +628,8 @@ def create_table(
)
)
elif obj is not None:
_read_in_memory(obj, name, self, overwrite=overwrite)
table_ident = sg.table(name, db=database, quoted=quoted).sql(self.dialect)
_read_in_memory(obj, table_ident, self, overwrite=overwrite)
return self.table(name, database=database)
else:
query = None
Expand Down Expand Up @@ -687,7 +688,7 @@ def truncate_table(
table_loc = self._warn_and_create_table_loc(database, schema)
catalog, db = self._to_catalog_db_tuple(table_loc)

ident = sg.table(name, db=db, catalog=catalog).sql(self.name)
ident = sg.table(name, db=db, catalog=catalog).sql(self.dialect)
with self._safe_raw_sql(sge.delete(ident)):
pass

Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/datafusion/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ def test_register_dataset(conn):
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(dataset, "my_table")
assert conn.table("my_table").x.sum().execute() == 6


def test_create_table_with_uppercase_name(conn):
tab = pa.table({"x": [1, 2, 3]})
conn.create_table("MY_TABLE", tab)
assert conn.table("MY_TABLE").x.sum().execute() == 6

0 comments on commit f11ca43

Please sign in to comment.