Skip to content

Commit

Permalink
fix(ddl): use column names, not position, for insertion order (#9264)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <[email protected]>
  • Loading branch information
gforsyth and cpcloud authored May 31, 2024
1 parent 77aaecd commit 3506f40
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 19 deletions.
3 changes: 1 addition & 2 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,7 @@ def insert(
elif not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

query = sge.insert(self.compile(obj), into=name, dialect=self.name)

query = self._build_insert_query(target=name, source=obj)
external_tables = self._collect_in_memory_tables(obj, {})
external_data = self._normalize_external_tables(external_tables)
return self.con.command(query.sql(self.name), external_data=external_data)
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/impala/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def insert(
if not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

if not set(self.columns).difference(obj.columns):
# project out using column order of parent table
# if column names match
obj = obj.select(self.columns)

self._client._run_pre_execute_hooks(obj)

expr = obj
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,13 +1138,13 @@ def insert(
if not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

table = sg.table(table_name, db=db, catalog=catalog, quoted=True)
self._run_pre_execute_hooks(obj)
query = sg.exp.insert(
expression=self.compile(obj),
into=table,
columns=[sg.to_identifier(col, quoted=True) for col in obj.columns],
dialect=self.name,

query = self._build_insert_query(
target=table_name, source=obj, db=db, catalog=catalog
)
table = sg.table(
table_name, db=db, catalog=catalog, quoted=self.compiler.quoted
)

statements = []
Expand Down
35 changes: 26 additions & 9 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,20 +431,37 @@ def insert(

self._run_pre_execute_hooks(obj)

query = self._build_insert_query(
target=table_name, source=obj, db=db, catalog=catalog
)

with self._safe_raw_sql(query):
pass

def _build_insert_query(
self, *, target: str, source, db: str | None = None, catalog: str | None = None
):
compiler = self.compiler
quoted = compiler.quoted
# Compare the columns between the target table and the object to be inserted
# If they don't match, assume auto-generated column names and use positional
# ordering.
source_cols = source.columns
columns = (
source_cols
if not set(target_cols := self.get_schema(target).names).difference(
source_cols
)
else target_cols
)

query = sge.insert(
expression=self.compile(obj),
into=sg.table(table_name, db=db, catalog=catalog, quoted=quoted),
columns=[
sg.to_identifier(col, quoted=quoted)
for col in self.get_schema(table_name).names
],
expression=self.compile(source),
into=sg.table(target, db=db, catalog=catalog, quoted=quoted),
columns=[sg.to_identifier(col, quoted=quoted) for col in columns],
dialect=compiler.dialect,
)

with self._safe_raw_sql(query):
pass
return query

def truncate_table(
self, name: str, database: str | None = None, schema: str | None = None
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,12 @@ def insert(
obj = ibis.memtable(obj)

self._run_pre_execute_hooks(obj)
expr = self._to_sqlglot(obj)
insert_stmt = sge.Insert(this=table, expression=expr).sql(self.name)

query = self._build_insert_query(
target=table_name, source=obj, catalog=database
)
insert_stmt = query.sql(self.name)

with self.begin() as cur:
if overwrite:
cur.execute(f"DELETE FROM {table.sql(self.name)}")
Expand Down
41 changes: 41 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,3 +1746,44 @@ def test_schema_with_caching(alltypes):

assert pt1.schema() == t1.schema()
assert pt2.schema() == t2.schema()


@pytest.mark.notyet(
["druid"], raises=NotImplementedError, reason="doesn't support create_table"
)
@pytest.mark.notyet(["pandas", "dask", "polars"], reason="Doesn't support insert")
@pytest.mark.notyet(
["datafusion"], reason="Doesn't support table creation from records"
)
@pytest.mark.parametrize(
"first_row, second_row",
[
param([{"a": 1, "b": 2}], [{"b": 22, "a": 11}], id="column order reversed"),
param([{"a": 1, "b": 2}], [{"a": 11, "b": 22}], id="column order matching"),
param(
[{"a": 1, "b": 2}],
[(11, 22)],
marks=[
pytest.mark.notimpl(
["impala"],
reason="Impala DDL has strict validation checks on schema",
)
],
id="auto generated cols",
),
],
)
def test_insert_using_col_name_not_position(con, first_row, second_row, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
table_name = gen_name("table")
con.create_table(table_name, first_row)
con.insert(table_name, second_row)

result = con.table(table_name).order_by("a").to_pyarrow()
expected_result = pa.table({"a": [1, 11], "b": [2, 22]})

assert result.equals(expected_result)

# Ideally we'd use a temp table for this test, but several backends don't
# support them and it's nice to know that data are being inserted correctly.
con.drop_table(table_name)

0 comments on commit 3506f40

Please sign in to comment.