Skip to content

Commit

Permalink
feat(api): add .delta method for computing difference in units betw…
Browse files Browse the repository at this point in the history
…een two temporal values
  • Loading branch information
cpcloud committed Oct 2, 2023
1 parent ac85d11 commit 18617bf
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 9 deletions.
31 changes: 31 additions & 0 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,34 @@ def _count_distinct_star(t, op):
)


def _time_delta(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return f"TIME_DIFF({left}, {right}, {op.part.value.upper()})"


def _date_delta(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return f"DATE_DIFF({left}, {right}, {op.part.value.upper()})"


def _timestamp_delta(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
left_tz = op.left.dtype.timezone
right_tz = op.right.dtype.timezone
args = f"{left}, {right}, {op.part.value.upper()}"
if left_tz is None and right_tz is None:
return f"DATETIME_DIFF({args})"
elif left_tz is not None and right_tz is not None:
return f"TIMESTAMP_DIFF({args})"
else:
raise NotImplementedError(
"timestamp difference with mixed timezone/timezoneless values is not implemented"
)


OPERATION_REGISTRY = {
**operation_registry,
# Literal
Expand Down Expand Up @@ -893,6 +921,9 @@ def _count_distinct_star(t, op):
ops.CountDistinctStar: _count_distinct_star,
ops.Argument: lambda _, op: op.name,
ops.Unnest: unary("UNNEST"),
ops.TimeDelta: _time_delta,
ops.DateDelta: _date_delta,
ops.TimestampDelta: _timestamp_delta,
}

_invalid_operations = {
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,3 +1052,9 @@ def _scalar_udf(op, **kw) -> str:
@translate_val.register(ops.AggUDF)
def _agg_udf(op, *, where, **kw) -> str:
return agg[op.__full_name__](*kw.values(), where=where)


@translate_val.register(ops.DateDelta)
@translate_val.register(ops.TimestampDelta)
def _delta(op, *, part, left, right, **_):
    """Lower date/timestamp delta ops to a sqlglot ``DateDiff`` expression."""
    diff = sg.exp.DateDiff(this=left, unit=part, expression=right)
    return diff
8 changes: 8 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ def _try_cast(t, op):
return try_cast(arg, type_=to)


def _date_diff_swapped(part, start, end):
    # DuckDB's date_diff(part, startdate, enddate) subtracts startdate from
    # enddate, so the translated operands are passed in swapped order —
    # presumably to yield ``left - right``; verify against sibling backends.
    return sa.func.date_diff(part, end, start)


_temporal_delta = fixed_arity(_date_diff_swapped, 3)


operation_registry.update(
{
ops.ArrayColumn: (
Expand Down Expand Up @@ -469,6 +474,9 @@ def _try_cast(t, op):
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
ops.ArrayIntersect: _array_intersect,
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
}
)

Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def _timestamp_truncate(t, op):
return sa.func.datetrunc(sa.text(_truncate_precisions[unit]), arg)


def _temporal_delta(t, op):
    """Translate temporal delta ops to MSSQL ``DATEDIFF``.

    ``DATEDIFF(part, start, end)`` computes ``end - start``, so ``right``
    is passed as the start operand to produce ``left - right``.
    """
    lhs = t.translate(op.left)
    rhs = t.translate(op.right)
    unit = sa.literal_column(op.part.value.upper())
    return sa.func.datediff(unit, rhs, lhs)


operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)

Expand Down Expand Up @@ -197,6 +203,9 @@ def _timestamp_truncate(t, op):
ops.ExtractMicrosecond: fixed_arity(
lambda arg: sa.func.datepart(sa.literal_column("microsecond"), arg), 1
),
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
}
)

Expand Down
31 changes: 22 additions & 9 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,27 @@ def _interval_from_integer(t, op):


def _literal(_, op):
if op.dtype.is_interval():
if op.dtype.unit.short in {"ms", "ns"}:
dtype = op.dtype
value = op.value
if dtype.is_interval():
if dtype.unit.short in {"ms", "ns"}:
raise com.UnsupportedOperationError(
"MySQL does not allow operation "
f"with INTERVAL offset {op.dtype.unit}"
f"MySQL does not allow operation with INTERVAL offset {dtype.unit}"
)
text_unit = op.dtype.resolution.upper()
text_unit = dtype.resolution.upper()
sa_text = sa.text(f"INTERVAL :value {text_unit}")
return sa_text.bindparams(value=op.value)
elif op.dtype.is_binary():
return sa_text.bindparams(value=value)
elif dtype.is_binary():
# the cast to BINARY is necessary here, otherwise the data come back as
# Python strings
#
# This lets the database handle encoding rather than ibis
return sa.cast(sa.literal(op.value), type_=sa.BINARY())
return sa.cast(sa.literal(value), type_=sa.BINARY())
elif dtype.is_time():
return sa.func.maketime(
value.hour, value.minute, value.second + value.microsecond / 1e6
)
else:
value = op.value
with contextlib.suppress(AttributeError):
value = value.to_pydatetime()

Expand Down Expand Up @@ -167,6 +171,13 @@ def compiles_mysql_trim(element, compiler, **kw):
)


def _temporal_delta(t, op):
    """Translate temporal delta ops to MySQL ``TIMESTAMPDIFF``.

    ``TIMESTAMPDIFF(unit, a, b)`` computes ``b - a``, so ``right`` is
    passed first to produce ``left - right``.
    """
    lhs = t.translate(op.left)
    rhs = t.translate(op.right)
    unit = sa.literal_column(op.part.value.upper())
    return sa.func.timestampdiff(unit, rhs, lhs)


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -241,6 +252,8 @@ def compiles_mysql_trim(element, compiler, **kw):
ops.Strip: unary(lambda arg: _mysql_trim(arg, "both")),
ops.LStrip: unary(lambda arg: _mysql_trim(arg, "leading")),
ops.RStrip: unary(lambda arg: _mysql_trim(arg, "trailing")),
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
}
)

Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def _literal(t, op):
return sa.func.timestamp_from_parts(*args)
elif dtype.is_date():
return sa.func.date_from_parts(value.year, value.month, value.day)
elif dtype.is_time():
nanos = value.microsecond * 1_000
return sa.func.time_from_parts(value.hour, value.minute, value.second, nanos)
elif dtype.is_array():
return sa.func.array_construct(*value)
elif dtype.is_map() or dtype.is_struct():
Expand Down Expand Up @@ -461,6 +464,15 @@ def _map_get(t, op):
ops.Levenshtein: fixed_arity(sa.func.editdistance, 2),
ops.ArraySort: unary(sa.func.ibis_udfs.public.array_sort),
ops.ArrayRepeat: fixed_arity(sa.func.ibis_udfs.public.array_repeat, 2),
ops.TimeDelta: fixed_arity(
lambda part, left, right: sa.func.timediff(part, right, left), 3
),
ops.DateDelta: fixed_arity(
lambda part, left, right: sa.func.datediff(part, right, left), 3
),
ops.TimestampDelta: fixed_arity(
lambda part, left, right: sa.func.timestampdiff(part, right, left), 3
),
}
)

Expand Down
59 changes: 59 additions & 0 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,3 +2400,62 @@ def test_timestamp_precision_output(con, ts, scale, unit):
result = con.execute(expr)
expected = pd.Timestamp(ts).floor(unit)
assert result == expected


# Backends below have no translation rule for the delta operations yet.
@pytest.mark.notimpl(
    [
        "dask",
        "datafusion",
        "druid",
        "flink",
        "impala",
        "oracle",
        "pandas",
        "polars",
        "pyspark",
        "sqlite",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["postgres"],
    reason="postgres doesn't have any easy way to accurately compute the delta in specific units",
    raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize(
    ("start", "end", "unit", "expected"),
    [
        # 01:58 -> 23:59:59 spans 22 whole hours (the partial 23rd hour
        # must not be counted).
        param(
            ibis.time("01:58:00"),
            ibis.time("23:59:59"),
            "hour",
            22,
            id="time",
            marks=[
                pytest.mark.notimpl(
                    ["clickhouse"],
                    raises=NotImplementedError,
                    reason="time types not yet implemented in ibis for the clickhouse backend",
                )
            ],
        ),
        param(ibis.date("1992-09-30"), ibis.date("1992-10-01"), "day", 1, id="date"),
        # 23:59:59 -> 01:58:00 next day crosses two hour boundaries, so the
        # delta is 2 even though less than 2 full hours elapsed.
        param(
            ibis.timestamp("1992-09-30 23:59:59"),
            ibis.timestamp("1992-10-01 01:58:00"),
            "hour",
            2,
            id="timestamp",
            marks=[
                pytest.mark.notimpl(
                    ["mysql"],
                    raises=com.OperationNotDefinedError,
                    reason="timestampdiff rounds after subtraction and mysql doesn't have a date_trunc function",
                )
            ],
        ),
    ],
)
def test_delta(con, start, end, unit, expected):
    """``end.delta(start, unit)`` counts whole ``unit`` boundaries crossed."""
    expr = end.delta(start, unit)
    assert con.execute(expr) == expected
12 changes: 12 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def _array_intersect(t, op):
)


def _truncated_date_diff(part, left, right):
    # Trino's date_diff rounds the raw difference, whereas the other
    # backends count boundary crossings; truncating each operand to
    # ``part`` first reproduces the boundary-crossing semantics.
    return sa.func.date_diff(
        part, sa.func.date_trunc(part, right), sa.func.date_trunc(part, left)
    )


_temporal_delta = fixed_arity(_truncated_date_diff, 3)

operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -501,6 +508,11 @@ def _array_intersect(t, op):
),
ops.Levenshtein: fixed_arity(sa.func.levenshtein_distance, 2),
ops.ArrayIntersect: _array_intersect,
# trino truncates _after_ computing the delta, whereas many other backends
# truncate each operand before diffing
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
}
)

Expand Down
24 changes: 24 additions & 0 deletions ibis/expr/operations/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,28 @@ class BetweenTime(Between):
upper_bound: Value[dt.Time | dt.String]


class TemporalDelta(Value):
    """Difference between two temporal values expressed as an integer
    number of ``part`` units (e.g. ``"hour"``, ``"day"``)."""

    # Unit in which the difference is measured.
    part: Value[dt.String]
    # Scalar/column shape follows the operands; the result is always int64.
    shape = rlz.shape_like("args")
    dtype = dt.int64


@public
class TimeDelta(TemporalDelta):
    """Difference between two time values in units of ``part``."""

    left: Value[dt.Time]
    right: Value[dt.Time]


@public
class DateDelta(TemporalDelta):
    """Difference between two date values in units of ``part``."""

    left: Value[dt.Date]
    right: Value[dt.Date]


@public
class TimestampDelta(TemporalDelta):
    """Difference between two timestamp values in units of ``part``."""

    left: Value[dt.Timestamp]
    right: Value[dt.Timestamp]


public(ExtractTimestampField=ExtractTemporalField)
Loading

0 comments on commit 18617bf

Please sign in to comment.