Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def cast_column_to_type(col: dd.Series, expected_type: str):

current_float = pd.api.types.is_float_dtype(current_type)
expected_integer = pd.api.types.is_integer_dtype(expected_type)
current_timedelta_type = pd.api.types.is_timedelta64_dtype(current_type)
if current_float and expected_integer:
logger.debug("...truncating...")
# Currently "trunc" can not be applied to NA (the pandas missing value type),
Expand All @@ -296,5 +297,10 @@ def cast_column_to_type(col: dd.Series, expected_type: str):
# will convert both NA and np.NaN to NA.
col = da.trunc(col.fillna(value=np.NaN))

if current_timedelta_type and expected_integer:
res = col.dt.total_seconds() * 1000
else:
res = col.astype(expected_type)

logger.debug(f"Need to cast from {current_type} to {expected_type}")
return col.astype(expected_type)
return res
50 changes: 50 additions & 0 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging
import operator
import re
Expand Down Expand Up @@ -200,7 +201,28 @@ def cast(self, operand, rex=None) -> SeriesOrScalar:
output_type = sql_to_python_type(output_type.upper())

return_column = cast_column_to_type(operand, output_type)
if return_column is None:
return operand
else:
return return_column


class ReinterpretOperation(Operation):
"""The cast operator"""

needs_rex = True

def __init__(self):
super().__init__(self.cast)

def cast(self, operand, rex=None) -> SeriesOrScalar:
if not is_frame(operand):
return operand

output_type = str(rex.getType())
output_type = sql_to_python_type(output_type.upper())

return_column = cast_column_to_type(operand, output_type)
if return_column is None:
return operand
else:
Expand Down Expand Up @@ -642,6 +664,32 @@ def random_function(self, partition, random_state, kwargs):
return random_state.randint(size=len(partition), low=0, **kwargs)


class IntDivisionOperator(Operation):
"""
Truncated integer division (so -1 / 2 = 0).
This is only used for internal calculations,
which are created by Calcite.
"""

def __init__(self):
super().__init__(self.div)

def div(self, lhs, rhs):
result = lhs / rhs

# Specialized code for literals like "1000µs"
# For some reasons, Calcite decides to represent
# 1000µs as 1000µs * 1000 / 1000
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
if isinstance(result, datetime.timedelta):
return result
else: # pragma: no cover
result = da.trunc(result)
return result


class SearchOperation(Operation):
"""
Search is a special operation in SQL, which allows to write "range-like"
Expand Down Expand Up @@ -701,8 +749,10 @@ class RexCallPlugin(BaseRexPlugin):
"*": ReduceOperation(operation=operator.mul),
"is distinct from": NotOperation().of(IsNotDistinctOperation()),
"is not distinct from": IsNotDistinctOperation(),
"/int": IntDivisionOperator(),
# special operations
"cast": CastOperation(),
"reinterpret": ReinterpretOperation(),
"case": CaseOperation(),
"like": LikeOperation(),
"similar to": SimilarOperation(),
Expand Down
62 changes: 62 additions & 0 deletions tests/integration/test_rex.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,65 @@ def test_date_functions(c):
FROM df
"""
).compute()


def test_timestampdiff(c):
# single value test
query = (
"SELECT timestampdiff(SECOND, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res0,"
"timestampdiff(MINUTE, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res1,"
"timestampdiff(HOUR, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res2,"
"timestampdiff(DAY, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res3,"
"timestampdiff(MONTH, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res4,"
"timestampdiff(YEAR, CAST('2002-03-07' AS TIMESTAMP),CAST('2002-06-05' AS TIMESTAMP)) as res5"
)

df = c.sql(query).compute()
assert df["res0"][0] == 7776000
assert df["res1"][0] == 129600
assert df["res2"][0] == 2160
assert df["res3"][0] == 90
assert df["res4"][0] == 2
assert df["res5"][0] == 0

# dataframe test

test = pd.DataFrame(
{
"a": ["2002-06-05 00:00:00", "2002-09-01 00:00:00", "2002-12-03 00:00:00"],
"b": ["2002-06-07 00:00:00", "2003-06-05 00:00:00", "2002-06-05 00:00:00"],
}
)

c.create_table("test", test)
query = (
"SELECT timestampdiff(MICROSECOND, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as ms,"
"timestampdiff(SECOND, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as sec,"
"timestampdiff(MINUTE, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as minn,"
"timestampdiff(HOUR, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as hr,"
"timestampdiff(DAY, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as dayy "
"FROM test"
)

ddf = c.sql(query).compute()

expected_df = pd.DataFrame(
{
"ms": {0: -1001308160, 1: -1242226688, 2: 424075264},
"sec": {0: -172800, 1: -23932800, 2: 15638400},
"minn": {0: -2880, 1: -398880, 2: 260640},
"hr": {0: -48, 1: -6648, 2: 4344},
"dayy": {0: -2, 1: -277, 2: 181},
}
)

assert_frame_equal(ddf, expected_df, check_dtype=False)

# as of now year and month was not working
query = (
"SELECT timestampdiff(MONTH, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as monthh,"
"timestampdiff(YEAR, CAST(b AS TIMESTAMP),CAST(a AS TIMESTAMP)) as yearr "
"FROM test"
)
with pytest.raises(Exception):
ddf = c.sql(query).compute()