Skip to content

Commit

Permalink
Standardize exceptions of arithmetic operations on Datetime-like data (#2101)
Browse files (browse the repository at this point in the history)

Standardize exceptions of arithmetic operations on Datetime-like data
  • Loading branch information
xinrong-meng authored Mar 31, 2021
1 parent b413231 commit fcf21dc
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
31 changes: 31 additions & 0 deletions databricks/koalas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ def __add__(self, other) -> Union["Series", "Index"]:
or isinstance(other, str)
):
raise TypeError("string addition can only be applied to string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("addition can not be applied to date times.")

if isinstance(self.spark.data_type, StringType):
# Concatenate string columns
if isinstance(other, IndexOpsMixin) and isinstance(other.spark.data_type, StringType):
Expand Down Expand Up @@ -390,6 +394,9 @@ def __mul__(self, other) -> Union["Series", "Index"]:
if isinstance(other, str):
raise TypeError("multiplication can not be applied to a string literal.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("multiplication can not be applied to date times.")

if (
isinstance(self.spark.data_type, IntegralType)
and isinstance(other, IndexOpsMixin)
Expand Down Expand Up @@ -434,6 +441,9 @@ def __truediv__(self, other) -> Union["Series", "Index"]:
):
raise TypeError("division can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("division can not be applied to date times.")

def truediv(left, right):
return F.when(F.lit(right != 0) | F.lit(right).isNull(), left.__div__(right)).otherwise(
F.when(F.lit(left == np.inf) | F.lit(left == -np.inf), left).otherwise(
Expand All @@ -451,6 +461,9 @@ def __mod__(self, other) -> Union["Series", "Index"]:
):
raise TypeError("modulo can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("modulo can not be applied to date times.")

def mod(left, right):
return ((left % right) + right) % right

Expand All @@ -461,6 +474,9 @@ def __radd__(self, other) -> Union["Series", "Index"]:
if not isinstance(self.spark.data_type, StringType) and isinstance(other, str):
raise TypeError("string addition can only be applied to string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("addition can not be applied to date times.")

if isinstance(self.spark.data_type, StringType):
if isinstance(other, str):
return self._with_new_scol(
Expand Down Expand Up @@ -507,6 +523,9 @@ def __rmul__(self, other) -> Union["Series", "Index"]:
if isinstance(other, str):
raise TypeError("multiplication can not be applied to a string literal.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("multiplication can not be applied to date times.")

if isinstance(self.spark.data_type, StringType):
if isinstance(other, int):
return column_op(SF.repeat)(self, other)
Expand All @@ -521,6 +540,9 @@ def __rtruediv__(self, other) -> Union["Series", "Index"]:
if isinstance(self.spark.data_type, StringType) or isinstance(other, str):
raise TypeError("division can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("division can not be applied to date times.")

def rtruediv(left, right):
return F.when(left == 0, F.lit(np.inf).__div__(right)).otherwise(
F.lit(right).__truediv__(left)
Expand Down Expand Up @@ -552,6 +574,9 @@ def __floordiv__(self, other) -> Union["Series", "Index"]:
):
raise TypeError("division can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("division can not be applied to date times.")

def floordiv(left, right):
return F.when(F.lit(right is np.nan), np.nan).otherwise(
F.when(
Expand All @@ -569,6 +594,9 @@ def __rfloordiv__(self, other) -> Union["Series", "Index"]:
if isinstance(self.spark.data_type, StringType) or isinstance(other, str):
raise TypeError("division can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("division can not be applied to date times.")

def rfloordiv(left, right):
return F.when(F.lit(left == 0), F.lit(np.inf).__div__(right)).otherwise(
F.when(F.lit(left) == np.nan, np.nan).otherwise(F.floor(F.lit(right).__div__(left)))
Expand All @@ -580,6 +608,9 @@ def __rmod__(self, other) -> Union["Series", "Index"]:
if isinstance(self.spark.data_type, StringType) or isinstance(other, str):
raise TypeError("modulo can not be applied on string series or literals.")

if isinstance(self.spark.data_type, TimestampType):
raise TypeError("modulo can not be applied to date times.")

def rmod(left, right):
return ((right % left) + left) % left

Expand Down
31 changes: 31 additions & 0 deletions databricks/koalas/tests/indexes/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,34 @@ def test_indexer_at_time(self):
NotImplementedError,
lambda: ks.DatetimeIndex([0]).indexer_at_time("00:00:00", asof=True),
)

def test_arithmetic_op_exceptions(self):
    """Check the errors raised by arithmetic operations on a datetime index.

    For every ``(kidx, pidx)`` pair in ``self.idx_pairs``:

    * ``+``, ``*``, ``/``, ``//`` and ``%`` (in both operand orders) must raise
      ``TypeError`` with the operation-specific "... can not be applied to date
      times." message, whatever the other operand is.
    * ``-`` with a non-datetime operand must raise ``TypeError`` with the
      "datetime subtraction can only be applied to datetime series." message.
    * ``py_datetime - kidx`` must raise ``NotImplementedError``.
    """
    for kidx, pidx in self.idx_pairs:
        py_datetime = pidx.to_pydatetime()

        # NOTE: every lambda below is invoked immediately by the assertion helper,
        # so closing over the loop variables (`kidx`, `other`) is safe here.
        for other in [1, 0.1, kidx, kidx.to_series().reset_index(drop=True), py_datetime]:
            expected_err_msg = "addition can not be applied to date times."
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx + other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other + kidx)

            expected_err_msg = "multiplication can not be applied to date times."
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx * other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other * kidx)

            expected_err_msg = "division can not be applied to date times."
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx / other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other / kidx)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx // other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other // kidx)

            expected_err_msg = "modulo can not be applied to date times."
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx % other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other % kidx)

        expected_err_msg = "datetime subtraction can only be applied to datetime series."

        for other in [1, 0.1]:
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx - other)
            self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other - kidx)

        # Name the operand explicitly: the previous `kidx - other` silently reused the
        # loop variable left over from the loop above (0.1) and so only repeated an
        # assertion that had already been made.
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kidx - py_datetime)
        self.assertRaises(NotImplementedError, lambda: py_datetime - kidx)
33 changes: 33 additions & 0 deletions databricks/koalas/tests/test_series_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,39 @@ def test_timestamp_subtraction(self):
with self.assertRaisesRegex(TypeError, expected_error_message):
1 - kdf["a"]

def test_arithmetic_op_exceptions(self):
    """Check the errors raised by arithmetic operations on a datetime series.

    Using ``self.ks_start_date`` as the datetime series under test:

    * ``+``, ``*``, ``/``, ``//`` and ``%`` (in both operand orders) must raise
      ``TypeError`` with the operation-specific "... can not be applied to date
      times." message, whatever the other operand is.
    * ``-`` with a non-datetime operand must raise ``TypeError`` with the
      "datetime subtraction can only be applied to datetime series." message.
    * ``py_datetime - kser`` must raise ``NotImplementedError``.
    """
    kser = self.ks_start_date
    py_datetime = self.pd_start_date.dt.to_pydatetime()
    datetime_index = ks.Index(self.pd_start_date)

    # NOTE: every lambda below is invoked immediately by the assertion helper,
    # so closing over the loop variable `other` is safe here.
    for other in [1, 0.1, kser, datetime_index, py_datetime]:
        expected_err_msg = "addition can not be applied to date times."
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser + other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other + kser)

        expected_err_msg = "multiplication can not be applied to date times."
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser * other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other * kser)

        expected_err_msg = "division can not be applied to date times."
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser / other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other / kser)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser // other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other // kser)

        expected_err_msg = "modulo can not be applied to date times."
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser % other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other % kser)

    expected_err_msg = "datetime subtraction can only be applied to datetime series."

    for other in [1, 0.1]:
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser - other)
        self.assertRaisesRegex(TypeError, expected_err_msg, lambda: other - kser)

    # Name the operand explicitly: the previous `kser - other` silently reused the
    # loop variable left over from the loop above (0.1) and so only repeated an
    # assertion that had already been made.
    self.assertRaisesRegex(TypeError, expected_err_msg, lambda: kser - py_datetime)
    self.assertRaises(NotImplementedError, lambda: py_datetime - kser)

def test_date_subtraction(self):
pdf = self.pdf1
kdf = ks.from_pandas(pdf)
Expand Down

0 comments on commit fcf21dc

Please sign in to comment.