diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 75eeb3e5c3..e715a2fe0c 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -4008,6 +4008,94 @@ def duplicated(self, subset=None, keep="first") -> "Series": ) ) + # TODO: support other as DataFrame or array-like + def dot(self, other: "Series") -> "Series": + """ + Compute the matrix multiplication between the DataFrame and other. + + This method computes the matrix product between the DataFrame and the + values of another Series. + + It can also be called using ``self @ other`` in Python >= 3.5. + + .. note:: This method is based on an expensive operation due to the nature + of big data. Internally it needs to generate each row for each value, and + then group twice - it is a huge operation. To prevent misuse, this method + caps the input length at 'compute.max_rows' and raises a ValueError beyond it. + + >>> from databricks.koalas.config import option_context + >>> with option_context( + ... 'compute.max_rows', 1000, "compute.ops_on_diff_frames", True + ... ): # doctest: +NORMALIZE_WHITESPACE + ... kdf = ks.DataFrame({'a': range(1001)}) + ... kser = ks.Series([2], index=['a']) + ... kdf.dot(kser) + Traceback (most recent call last): + ... + ValueError: Current DataFrame has more then the given limit 1000 rows. + Please set 'compute.max_rows' by using 'databricks.koalas.config.set_option' + to retrieve to retrieve more than 1000 rows. Note that, before changing the + 'compute.max_rows', this operation is considerably expensive. + + Parameters + ---------- + other : Series + The other object to compute the matrix product with. + + Returns + ------- + Series + Return the matrix product between self and other as a Series. + + See Also + -------- + Series.dot: Similar method for Series. + + Notes + ----- + The dimensions of DataFrame and other must be compatible in order to + compute the matrix multiplication. 
In addition, the column names of + DataFrame and the index of other must contain the same values, as they + will be aligned prior to the multiplication. + + The dot method for Series computes the inner product, instead of the + matrix product here. + + Examples + -------- + >>> from databricks.koalas.config import set_option, reset_option + >>> set_option("compute.ops_on_diff_frames", True) + >>> kdf = ks.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + >>> kser = ks.Series([1, 1, 2, 1]) + >>> kdf.dot(kser) + 0 -4 + 1 5 + dtype: int64 + + Note how shuffling of the objects does not change the result. + + >>> kser2 = kser.reindex([1, 0, 2, 3]) + >>> kdf.dot(kser2) + 0 -4 + 1 5 + dtype: int64 + >>> kdf @ kser2 + 0 -4 + 1 5 + dtype: int64 + >>> reset_option("compute.ops_on_diff_frames") + """ + if not isinstance(other, ks.Series): + raise TypeError("Unsupported type {}".format(type(other).__name__)) + else: + return cast(ks.Series, other.dot(self.transpose())).rename(None) + + def __matmul__(self, other): + """ + Matrix multiplication using binary `@` operator in Python>=3.5. + """ + return self.dot(other) + def to_koalas(self, index_col: Optional[Union[str, List[str]]] = None) -> "DataFrame": """ Converts the existing DataFrame into a Koalas DataFrame. 
diff --git a/databricks/koalas/missing/frame.py b/databricks/koalas/missing/frame.py index 95761dda7d..7b0de83460 100644 --- a/databricks/koalas/missing/frame.py +++ b/databricks/koalas/missing/frame.py @@ -47,7 +47,6 @@ class _MissingPandasLikeDataFrame(object): convert_dtypes = _unsupported_function("convert_dtypes") corrwith = _unsupported_function("corrwith") cov = _unsupported_function("cov") - dot = _unsupported_function("dot") ewm = _unsupported_function("ewm") first = _unsupported_function("first") infer_objects = _unsupported_function("infer_objects") diff --git a/databricks/koalas/tests/test_ops_on_diff_frames.py b/databricks/koalas/tests/test_ops_on_diff_frames.py index 212c79e541..d20a4c73de 100644 --- a/databricks/koalas/tests/test_ops_on_diff_frames.py +++ b/databricks/koalas/tests/test_ops_on_diff_frames.py @@ -852,7 +852,7 @@ def test_multi_index_column_assignment_frame(self): with self.assertRaisesRegex(KeyError, "Key length \\(3\\) exceeds index depth \\(2\\)"): kdf[("1", "2", "3")] = ks.Series([100, 200, 300, 200]) - def test_dot(self): + def test_series_dot(self): pser = pd.Series([90, 91, 85], index=[2, 4, 1]) kser = ks.from_pandas(pser) pser_other = pd.Series([90, 91, 85], index=[2, 4, 1]) @@ -917,6 +917,57 @@ def test_dot(self): pdf = kdf.to_pandas() self.assert_eq(kser.dot(kdf), pser.dot(pdf)) + def test_frame_dot(self): + pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + kdf = ks.from_pandas(pdf) + + pser = pd.Series([1, 1, 2, 1]) + kser = ks.from_pandas(pser) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # Index reorder + pser = pser.reindex([1, 0, 2, 3]) + kser = ks.from_pandas(pser) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # ser with name + pser.name = "ser" + kser = ks.from_pandas(pser) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # df with MultiIndex as column (ser with MultiIndex) + arrays = [[1, 1, 2, 2], ["red", "blue", "red", "blue"]] + pidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color")) + 
pser = pd.Series([1, 1, 2, 1], index=pidx) + pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]], columns=pidx) + kdf = ks.from_pandas(pdf) + kser = ks.from_pandas(pser) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # df with Index as column (ser with Index) + pidx = pd.Index([1, 2, 3, 4], name="number") + pser = pd.Series([1, 1, 2, 1], index=pidx) + pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]], columns=pidx) + kdf = ks.from_pandas(pdf) + kser = ks.from_pandas(pser) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # df with Index + pdf.index = pd.Index(["x", "y"], name="char") + kdf = ks.from_pandas(pdf) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + # df with MultiIndex + pdf.index = pd.MultiIndex.from_arrays([[1, 1], ["red", "blue"]], names=("number", "color")) + kdf = ks.from_pandas(pdf) + self.assert_eq(kdf.dot(kser), pdf.dot(pser)) + + pdf = pd.DataFrame([[1, 2], [3, 4]]) + kdf = ks.from_pandas(pdf) + self.assert_eq(kdf.dot(kdf[0]), pdf.dot(pdf[0])) + self.assert_eq(kdf.dot(kdf[0] * 10), pdf.dot(pdf[0] * 10)) + self.assert_eq((kdf + 1).dot(kdf[0] * 10), (pdf + 1).dot(pdf[0] * 10)) + def test_to_series_comparison(self): kidx1 = ks.Index([1, 2, 3, 4, 5]) kidx2 = ks.Index([1, 2, 3, 4, 5]) diff --git a/docs/source/reference/frame.rst b/docs/source/reference/frame.rst index 6dfde64dc1..9dbd6b0a42 100644 --- a/docs/source/reference/frame.rst +++ b/docs/source/reference/frame.rst @@ -98,6 +98,7 @@ Binary operator functions DataFrame.ge DataFrame.ne DataFrame.eq + DataFrame.dot Function application, GroupBy & Window --------------------------------------