Commit 1737eb0

Complete NumPy universal functions for DataFrames
1 parent eb763ea commit 1737eb0

2 files changed: +121 −30 lines changed

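With this change, NumPy universal functions (ufuncs) called on a Koalas DataFrame dispatch through the new __array_ufunc__ hook and evaluate as Spark column expressions, returning a Koalas DataFrame just as they already did for Series. A minimal usage sketch, assuming a Koalas build that includes this commit (np.sqrt and np.add stand in for any ufunc in Koalas' supported mappings):

    import numpy as np
    import pandas as pd
    import databricks.koalas as ks

    kdf = ks.from_pandas(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}))

    # Unary ufunc: applied column by column, returns a Koalas DataFrame.
    print(np.sqrt(kdf))

    # Binary ufunc with a scalar: the scalar is passed through to every column.
    print(np.add(kdf, 1))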

databricks/koalas/frame.py (+41 −1)
@@ -26,7 +26,7 @@
 from functools import partial, reduce
 import sys
 from itertools import zip_longest
-from typing import Any, Optional, List, Tuple, Union, Generic, TypeVar, Iterable, Dict
+from typing import Any, Optional, List, Tuple, Union, Generic, TypeVar, Iterable, Dict, Callable

 import numpy as np
 import pandas as pd
@@ -8499,6 +8499,46 @@ def __dir__(self):
     def __iter__(self):
         return iter(self.columns)

+    # NDArray Compat
+    def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
+        # TODO: is it possible to deduplicate it with '_map_series_op'?
+        if (all(isinstance(inp, DataFrame) for inp in inputs)
+                and any(inp is not inputs[0] for inp in inputs)):
+            # binary only
+            assert len(inputs) == 2
+            this = inputs[0]
+            that = inputs[1]
+            if this._internal.column_index_level != that._internal.column_index_level:
+                raise ValueError('cannot join with no overlapping index names')
+
+            # Different DataFrames
+            def apply_op(kdf, this_column_index, that_column_index):
+                for this_idx, that_idx in zip(this_column_index, that_column_index):
+                    yield (ufunc(kdf[this_idx], kdf[that_idx], **kwargs), this_idx)
+
+            return align_diff_frames(apply_op, this, that, fillna=True, how="full")
+        else:
+            # DataFrame and Series
+            applied = []
+            this = inputs[0]
+            assert all(inp is this for inp in inputs if isinstance(inp, DataFrame))
+
+            for idx in this._internal.column_index:
+                arguments = []
+                for inp in inputs:
+                    arguments.append(inp[idx] if isinstance(inp, DataFrame) else inp)
+                # both binary and unary.
+                applied.append(ufunc(*arguments, **kwargs))
+
+            sdf = this._sdf.select(
+                this._internal.index_scols + [c._scol for c in applied])
+            internal = this._internal.copy(sdf=sdf,
+                                           column_index=[c._internal.column_index[0]
+                                                         for c in applied],
+                                           column_scols=[scol_for(sdf, c._internal.data_columns[0])
+                                                         for c in applied])
+            return DataFrame(internal)
+
     if sys.version_info >= (3, 7):
         def __class_getitem__(cls, params):
             # This is a workaround to support variadic generic in DataFrame in Python 3.7.
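The new method takes two paths. When the inputs are two distinct DataFrames, the ufunc is routed through align_diff_frames (a full outer alignment with fillna=True), which requires the 'compute.ops_on_diff_frames' option because it triggers a join under the hood. Otherwise every DataFrame input must be the same anchor frame, and the ufunc is applied column by column. A hedged sketch of the diff-frames branch from the caller's side, assuming a Koalas build with this commit:

    import numpy as np
    import pandas as pd
    import databricks.koalas as ks

    kdf1 = ks.from_pandas(pd.DataFrame({'a': [1, 2], 'b': [3, 4]}))
    kdf2 = ks.from_pandas(pd.DataFrame({'a': [10, 20], 'b': [30, 40]}))

    # Operating on two different DataFrames must be enabled explicitly.
    ks.set_option('compute.ops_on_diff_frames', True)
    try:
        # Dispatches to DataFrame.__array_ufunc__, which aligns the two
        # frames via align_diff_frames(..., fillna=True, how="full").
        print(np.add(kdf1, kdf2).sort_index())
    finally:
        ks.reset_option('compute.ops_on_diff_frames')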

databricks/koalas/tests/test_numpy_compat.py (+80 −29)
@@ -23,6 +23,29 @@


 class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
+    blacklist = [
+        # Koalas does not currently support
+        "conj",
+        "conjugate",
+        "isnat",
+        "matmul",
+        "frexp",
+
+        # Values are close enough but tests failed.
+        "arccos",
+        "exp",
+        "expm1",
+        "log",  # flaky
+        "log10",  # flaky
+        "log1p",  # flaky
+        "modf",
+        "floor_divide",  # flaky
+
+        # Results seem inconsistent in a different version of, I (Hyukjin) suspect, PyArrow.
+        # From PyArrow 0.15, seems it returns the correct results via PySpark. Probably we
+        # can enable it later when Koalas switches to PyArrow 0.15 completely.
+        "left_shift",
+    ]

     @property
     def pdf(self):
@@ -49,12 +72,17 @@ def test_np_add_index(self):
         p_index = self.pdf.index
         self.assert_eq(np.add(k_index, k_index), np.add(p_index, p_index))

-    def test_np_unsupported(self):
+    def test_np_unsupported_series(self):
         kdf = self.kdf
         with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
             np.sqrt(kdf.a, kdf.b)

-    def test_np_spark_compat(self):
+    def test_np_unsupported_frame(self):
+        kdf = self.kdf
+        with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
+            np.sqrt(kdf, kdf)
+
+    def test_np_spark_compat_series(self):
         # Use randomly generated dataFrame
         pdf = pd.DataFrame(
             np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=['a', 'b'])
@@ -63,33 +91,9 @@ def test_np_spark_compat(self):
         kdf = ks.from_pandas(pdf)
         kdf2 = ks.from_pandas(pdf2)

-        blacklist = [
-            # Koalas does not currently support
-            "conj",
-            "conjugate",
-            "isnat",
-            "matmul",
-            "frexp",
-
-            # Values are close enough but tests failed.
-            "arccos",
-            "exp",
-            "expm1",
-            "log",  # flaky
-            "log10",  # flaky
-            "log1p",  # flaky
-            "modf",
-            "floor_divide",  # flaky
-
-            # Results seem inconsistent in a different version of, I (Hyukjin) suspect, PyArrow.
-            # From PyArrow 0.15, seems it returns the correct results via PySpark. Probably we
-            # can enable it later when Koalas switches to PyArrow 0.15 completely.
-            "left_shift",
-        ]
-
         for np_name, spark_func in unary_np_spark_mappings.items():
             np_func = getattr(np, np_name)
-            if np_name not in blacklist:
+            if np_name not in self.blacklist:
                 try:
                     # unary ufunc
                     self.assert_eq(np_func(pdf.a), np_func(kdf.a), almost=True)
@@ -98,7 +102,7 @@ def test_np_spark_compat(self):

         for np_name, spark_func in binary_np_spark_mappings.items():
             np_func = getattr(np, np_name)
-            if np_name not in blacklist:
+            if np_name not in self.blacklist:
                 try:
                     # binary ufunc
                     self.assert_eq(
@@ -113,7 +117,7 @@ def test_np_spark_compat(self):
             set_option('compute.ops_on_diff_frames', True)
             for np_name, spark_func in list(binary_np_spark_mappings.items())[:5]:
                 np_func = getattr(np, np_name)
-                if np_name not in blacklist:
+                if np_name not in self.blacklist:
                     try:
                         # binary ufunc
                         self.assert_eq(
@@ -123,3 +127,50 @@ def test_np_spark_compat(self):
                         raise AssertionError("Test in '%s' function was failed." % np_name) from e
                     finally:
                         reset_option('compute.ops_on_diff_frames')
+
+    def test_np_spark_compat_frame(self):
+        # Use randomly generated dataFrame
+        pdf = pd.DataFrame(
+            np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=['a', 'b'])
+        pdf2 = pd.DataFrame(
+            np.random.randint(-100, 100, size=(len(pdf), len(pdf.columns))), columns=['a', 'b'])
+        kdf = ks.from_pandas(pdf)
+        kdf2 = ks.from_pandas(pdf2)
+
+        for np_name, spark_func in unary_np_spark_mappings.items():
+            np_func = getattr(np, np_name)
+            if np_name not in self.blacklist:
+                try:
+                    # unary ufunc
+                    self.assert_eq(np_func(pdf), np_func(kdf), almost=True)
+                except Exception as e:
+                    raise AssertionError("Test in '%s' function was failed." % np_name) from e
+
+        for np_name, spark_func in binary_np_spark_mappings.items():
+            np_func = getattr(np, np_name)
+            if np_name not in self.blacklist:
+                try:
+                    # binary ufunc
+                    self.assert_eq(
+                        np_func(pdf, pdf), np_func(kdf, kdf), almost=True)
+                    self.assert_eq(
+                        np_func(pdf, 1), np_func(kdf, 1), almost=True)
+                except Exception as e:
+                    raise AssertionError("Test in '%s' function was failed." % np_name) from e
+
+        # Test only top 5 for now. 'compute.ops_on_diff_frames' option increases too much time.
+        try:
+            set_option('compute.ops_on_diff_frames', True)
+            for np_name, spark_func in list(binary_np_spark_mappings.items())[:5]:
+                np_func = getattr(np, np_name)
+                if np_name not in self.blacklist:
+                    try:
+                        # binary ufunc
+                        self.assert_eq(
+                            np_func(pdf, pdf2).sort_index(),
+                            np_func(kdf, kdf2).sort_index(), almost=True)

+                    except Exception as e:
+                        raise AssertionError("Test in '%s' function was failed." % np_name) from e
+        finally:
+            reset_option('compute.ops_on_diff_frames')
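The new test_np_spark_compat_frame mirrors the Series variant: it sweeps unary_np_spark_mappings and binary_np_spark_mappings from databricks.koalas.numpy_compat and compares each non-blacklisted ufunc's result on a Koalas DataFrame against pandas. A minimal standalone version of the same check for a single ufunc might look like this (np.sin is an arbitrary non-blacklisted choice):

    import numpy as np
    import pandas as pd
    import databricks.koalas as ks

    pdf = pd.DataFrame(np.random.randint(-100, 100, size=(50, 2)), columns=['a', 'b'])
    kdf = ks.from_pandas(pdf)

    # Apply the ufunc on both sides and compare, tolerating float rounding
    # differences (the test suite passes almost=True for the same reason).
    expected = np.sin(pdf).sort_index()
    actual = np.sin(kdf).to_pandas().sort_index()
    pd.testing.assert_frame_equal(actual, expected, check_exact=False)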
