Skip to content

Commit

Permalink
fix(sql): generate consistent pivot_longer semantics in the presence of multiple `unnest`s
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 5, 2023
1 parent 4a0ff1f commit 6bc301a
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 22 deletions.
9 changes: 6 additions & 3 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.common.collections import FrozenDict

if geospatial_supported:
import geoalchemy2 as ga
Expand All @@ -29,10 +30,12 @@ def compiles_array(element, compiler, **kw):


class StructType(sat.UserDefinedType):
    """SQLAlchemy user-defined type representing a struct of named fields."""

    # Compiled forms may be cached: the field mapping below is immutable.
    cache_ok = True

    def __init__(self, fields: Mapping[str, sat.TypeEngine]) -> None:
        """Store an immutable name -> SQLAlchemy type-instance mapping.

        FrozenDict keeps the attribute hashable/immutable, which is what
        ``cache_ok = True`` requires of the type's state.
        """
        self.fields = FrozenDict(
            {name: sat.to_instance(typ) for name, typ in fields.items()}
        )


@compiles(StructType, "default")
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def _array_filter(t, op):
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Unnest: unary(sa.func.unnest),
}
)

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):
_has_reduction_filter_syntax = True
_dialect_name = "postgresql"

# it does support it, but we can't use it because of support for pivot
supports_unnest_in_select = False


rewrites = PostgreSQLExprTranslator.rewrites

Expand Down
60 changes: 59 additions & 1 deletion ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction

import ibis.backends.base.sql.registry.geospatial as geo
import ibis.common.exceptions as com
Expand Down Expand Up @@ -488,6 +490,60 @@ def _arbitrary(t, op):
return t._reduction(func, op)


class struct_field(GenericFunction):
    # Marker function node; compiled below into postgres `(arg).field` syntax.
    inherit_cache = True


@compiles(struct_field)
def compile_struct_field_postgresql(element, compiler, **kw):
    """Render ``struct_field(arg, name)`` as the postgres ``(arg).name`` accessor."""
    arg_clause, field_clause = element.clauses
    rendered_arg = compiler.process(arg_clause, **kw)
    return f"({rendered_arg}).{field_clause.name}"


def _struct_field(t, op):
    """Translate a StructField access into a positional ``f<N>`` lookup.

    Struct columns are rendered with positional field names f1..fN (see the
    StructColumn rule), so the logical field name is mapped to its 1-based
    position here.
    """
    parent = op.arg
    position = parent.output_dtype.names.index(op.field) + 1
    return struct_field(
        t.translate(parent),
        sa.literal_column(f"f{position:d}"),
        type_=t.get_sqla_type(op.output_dtype),
    )


def _struct_column(t, op):
    """Translate a StructColumn into a postgres ``ROW(...)`` with fields f1..fN."""
    value_types = op.output_dtype.types
    # Postgres refuses the bare values inside ROW(...): cast each explicitly.
    cast_values = (
        t.translate(ops.Cast(value, typ))
        for value, typ in zip(op.values, value_types)
    )
    row_dtype = dt.Struct(
        {f"f{i:d}": typ for i, typ in enumerate(value_types, start=1)}
    )
    return sa.func.row(*cast_values, type_=t.get_sqla_type(row_dtype))


def _unnest(t, op):
    """Expand an array column into rows, yielding a single output column.

    Struct-typed array elements are unnested into one derived column per
    field and then re-packed with ``row(...)`` so this rule always returns
    exactly one column.
    """
    array_arg = op.arg
    element_type = array_arg.output_dtype.value_type
    unpacks_struct = element_type.is_struct()

    # A struct element contributes one derived column per field; any other
    # element type contributes exactly one column.
    element_types = getattr(element_type, "types", (element_type,))
    derived_columns = [
        sa.column(f"f{i:d}", sqla_type)
        for i, sqla_type in enumerate(map(t.get_sqla_type, element_types), start=1)
    ]
    derived = (
        sa.func.unnest(t.translate(array_arg))
        .table_valued(*derived_columns)
        .render_derived(with_types=unpacks_struct)
    )

    if unpacks_struct:
        # Re-pack the per-field columns into a single row column.
        return sa.func.row(*derived.c)
    return derived.c[0]


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -594,7 +650,7 @@ def _arbitrary(t, op):
),
ops.ArrayConcat: fixed_arity(sa.sql.expression.ColumnElement.concat, 2),
ops.ArrayRepeat: _array_repeat,
ops.Unnest: unary(sa.func.unnest),
ops.Unnest: _unnest,
ops.Covariance: _covar,
ops.Correlation: _corr,
ops.BitwiseXor: _bitwise_op("#"),
Expand Down Expand Up @@ -639,5 +695,7 @@ def _arbitrary(t, op):
ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)),
ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2),
ops.Arbitrary: _arbitrary,
ops.StructColumn: _struct_column,
ops.StructField: _struct_field,
}
)
19 changes: 17 additions & 2 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ def query(t, group_cols):
)
def test_pivot_longer(backend):
diamonds = backend.diamonds
df = diamonds.execute()
res = diamonds.pivot_longer(s.c("x", "y", "z"), names_to="pos", values_to="xyz")
assert res.schema().names == (
"carat",
Expand All @@ -999,8 +1000,22 @@ def test_pivot_longer(backend):
"pos",
"xyz",
)
df = res.limit(5).execute()
assert not df.empty
expected = pd.melt(
df,
id_vars=[
"carat",
"cut",
"color",
"clarity",
"depth",
"table",
"price",
],
value_vars=list('xyz'),
var_name="pos",
value_name="xyz",
)
assert len(res.execute()) == len(expected)


@pytest.mark.notyet(["datafusion"], raises=com.OperationNotDefinedError)
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/trino/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _timestamp(_, itype):
return TIMESTAMP(precision=itype.scale, timezone=bool(itype.timezone))


@compiles(TIMESTAMP, "trino")
@compiles(TIMESTAMP)
def compiles_timestamp(typ, compiler, **kw):
result = "TIMESTAMP"

Expand All @@ -150,7 +150,7 @@ def compiles_timestamp(typ, compiler, **kw):
return result


@compiles(ROW, "trino")
@compiles(ROW)
def _compiles_row(element, compiler, **kw):
# TODO: @compiles should live in the dialect
quote = compiler.dialect.identifier_preparer.quote
Expand All @@ -168,7 +168,7 @@ def _map(dialect, itype):
)


@compiles(MAP, "trino")
@compiles(MAP)
def compiles_map(typ, compiler, **kw):
# TODO: @compiles should live in the dialect
key_type = compiler.process(typ.key_type, **kw)
Expand All @@ -191,7 +191,7 @@ def _real(*_):
return sa.REAL()


@compiles(DOUBLE, "trino")
@compiles(DOUBLE)
@compiles(sa.REAL, "trino")
def _floating(element, compiler, **kw):
return type(element).__name__.upper()
9 changes: 8 additions & 1 deletion ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,14 @@ def _round(t, op):
def _unnest(t, op):
    """Expand an array column into rows, yielding a single output column.

    Struct-typed elements expose one derived column per field name; those are
    re-packed with ``row(...)`` and cast back so trino keeps the struct type.
    """
    array_arg = op.arg
    label = array_arg.name
    element_type = op.arg.output_dtype.value_type
    # Scalar elements expose a single column under the argument's own name.
    field_names = getattr(element_type, "names", (label,))
    derived = (
        sa.func.unnest(t.translate(array_arg))
        .table_valued(*field_names)
        .render_derived()
    )
    if len(field_names) == 1:
        return derived.c[0]
    # wrap in a row column so that we can return a single column from this rule
    packed = sa.func.row(*(derived.c[field] for field in field_names))
    return sa.cast(packed, t.get_sqla_type(element_type))


def _where(t, op):
Expand Down
25 changes: 14 additions & 11 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3104,21 +3104,24 @@ def pivot_longer(
elif isinstance(values_transform, Deferred):
values_transform = values_transform.resolve

names_map = {name: [] for name in names_to}
values = []
pieces = []

for pivot_col in pivot_cols:
col_name = pivot_col.get_name()
match_result = names_pattern.match(col_name)
for name, value in zip(names_to, match_result.groups()):
transformer = names_transform[name]
names_map[name].append(transformer(value))
values.append(values_transform(pivot_col))

new_cols = {key: ibis.array(value).unnest() for key, value in names_map.items()}
new_cols[values_to] = ibis.array(values).unnest()

return self.select(~pivot_sel, **new_cols)
row = {
name: names_transform[name](value)
for name, value in zip(names_to, match_result.groups())
}
row[values_to] = values_transform(pivot_col)
pieces.append(ibis.struct(row))

# nest into an array of structs to zip unnests together
pieces = ibis.array(pieces)

return self.select(~pivot_sel, __pivoted__=pieces.unnest()).unpack(
"__pivoted__"
)

@util.experimental
def pivot_wider(
Expand Down

0 comments on commit 6bc301a

Please sign in to comment.