diff --git a/python/ray/dataframe/dataframe.py b/python/ray/dataframe/dataframe.py
index 561111677ee8..b3223682a839 100644
--- a/python/ray/dataframe/dataframe.py
+++ b/python/ray/dataframe/dataframe.py
@@ -4464,9 +4464,120 @@ def remote_func(df):
 
     def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
               errors='raise', try_cast=False, raise_on_error=None):
-        raise NotImplementedError(
-            "To contribute to Pandas on Ray, please visit "
-            "github.com/ray-project/ray.")
+        """Replaces values not meeting condition with values in other.
+
+        Args:
+            cond: A condition to be met, can be callable, array-like or a
+                DataFrame. A callable is invoked with this DataFrame; an
+                array-like must have the same shape as this DataFrame.
+            other: A scalar, array, Series or DataFrame of values to use
+                where cond is not met.
+            inplace: Whether or not to operate inplace.
+            axis: The axis to apply over. Only valid when a Series is passed
+                as other.
+            level: The MultiLevel index level to apply over. Not yet
+                supported.
+            errors: Whether or not to raise errors. Does nothing in Pandas.
+            try_cast: Try to cast the result back to the input type.
+            raise_on_error: Whether to raise invalid datatypes (deprecated).
+
+        Returns:
+            A new DataFrame with the replaced values, or None if inplace is
+            True.
+
+        Raises:
+            ValueError: If other is a Series and axis is not given, or if an
+                array-like cond does not have the same shape as self.
+            NotImplementedError: If level is not None.
+        """
+
+        inplace = validate_bool_kwarg(inplace, 'inplace')
+
+        if isinstance(other, pd.Series) and axis is None:
+            raise ValueError("Must specify axis=0 or 1")
+
+        if level is not None:
+            raise NotImplementedError("Multilevel Index not yet supported on "
+                                      "Pandas on Ray.")
+
+        axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0
+
+        cond = cond(self) if callable(cond) else cond
+
+        if not isinstance(cond, DataFrame):
+            if not hasattr(cond, 'shape'):
+                cond = np.asanyarray(cond)
+            if cond.shape != self.shape:
+                raise ValueError("Array conditional must be same shape as "
+                                 "self")
+            cond = DataFrame(cond, index=self.index, columns=self.columns)
+
+        # An array of replacement values shaped like self is wrapped in a
+        # DataFrame so that it is split into the same row partitions as self
+        # instead of being shipped whole to every partition.
+        if isinstance(other, np.ndarray) and other.shape == self.shape:
+            other = DataFrame(other, index=self.index, columns=self.columns)
+
+        zipped_partitions = self._copartition(cond, self.index)
+        args = (False, axis, level, errors, try_cast, raise_on_error)
+
+        if isinstance(other, DataFrame):
+            other_zipped = (v for k, v in self._copartition(other,
+                                                            self.index))
+
+            new_partitions = [_where_helper.remote(k, v, next(other_zipped),
+                                                   self.columns, cond.columns,
+                                                   other.columns, *args)
+                              for k, v in zipped_partitions]
+
+        # Series has to be treated specially because we're operating on row
+        # partitions from here on.
+        elif isinstance(other, pd.Series):
+            if axis == 0:
+                # Pandas determines which index to use based on axis.
+                other = other.reindex(self.index)
+                other.index = pd.RangeIndex(len(other))
+
+                # Since we're working on row partitions, we have to partition
+                # the Series based on the partitioning of self (since both
+                # self and cond are co-partitioned by self).
+                other_builder = []
+                for length in self._row_metadata._lengths:
+                    other_builder.append(other[:length])
+                    other = other[length:]
+                    # Resetting the index here ensures that we apply each part
+                    # to the correct row within the partitions.
+                    other.index = pd.RangeIndex(len(other))
+
+                other = iter(other_builder)
+
+                new_partitions = [_where_helper.remote(k, v, next(other,
+                                                                  pd.Series()),
+                                                       self.columns,
+                                                       cond.columns,
+                                                       None, *args)
+                                  for k, v in zipped_partitions]
+            else:
+                other = other.reindex(self.columns)
+                new_partitions = [_where_helper.remote(k, v, other,
+                                                       self.columns,
+                                                       cond.columns,
+                                                       None, *args)
+                                  for k, v in zipped_partitions]
+
+        else:
+            new_partitions = [_where_helper.remote(k, v, other, self.columns,
+                                                   cond.columns, None, *args)
+                              for k, v in zipped_partitions]
+
+        if inplace:
+            self._update_inplace(row_partitions=new_partitions,
+                                 row_metadata=self._row_metadata,
+                                 col_metadata=self._col_metadata)
+        else:
+            return DataFrame(row_partitions=new_partitions,
+                             row_metadata=self._row_metadata,
+                             col_metadata=self._col_metadata)
 
     def xs(self, key, axis=0, level=None, drop_level=True):
         raise NotImplementedError(
@@ -5093,3 +5204,48 @@ def _merge_columns(left_columns, right_columns, *args):
     return pd.DataFrame(columns=left_columns, index=[0], dtype='uint8').merge(
         pd.DataFrame(columns=right_columns, index=[0], dtype='uint8'),
         *args).columns
+
+
+@ray.remote
+def _where_helper(left, cond, other, left_columns, cond_columns,
+                  other_columns, *args):
+    """Applies pandas' where to a single row partition.
+
+    Args:
+        left: The block partitions that make up this row partition of the
+            calling DataFrame.
+        cond: The matching block partitions of the condition.
+        other: The replacement values. Treated as block partitions of a
+            DataFrame only when other_columns is given; anything else is
+            handed to pandas unchanged.
+        left_columns: The column labels to set on left.
+        cond_columns: The column labels to set on cond.
+        other_columns: The column labels of other when it is a DataFrame,
+            otherwise None.
+        args: The remaining positional arguments for pandas' where.
+
+    Returns:
+        The result of pandas' where for this row partition.
+    """
+    left = pd.concat(ray.get(left.tolist()), axis=1)
+    # We have to reset the index and columns here because we are coming
+    # from blocks and the axes are set according to the blocks. We have
+    # already correctly copartitioned everything, so there are no
+    # correctness problems with doing this.
+    left.reset_index(inplace=True, drop=True)
+    left.columns = left_columns
+
+    cond = pd.concat(ray.get(cond.tolist()), axis=1)
+    cond.reset_index(inplace=True, drop=True)
+    cond.columns = cond_columns
+
+    # Only a DataFrame `other` arrives as block partitions, and it is the
+    # only case in which the caller passes column labels. Checking the type
+    # of `other` instead would mistake a user-supplied numpy array for an
+    # array of object IDs.
+    if other_columns is not None:
+        other = pd.concat(ray.get(other.tolist()), axis=1)
+        other.reset_index(inplace=True, drop=True)
+        other.columns = other_columns
+
+    return left.where(cond, other, *args)
diff --git a/python/ray/dataframe/test/test_dataframe.py b/python/ray/dataframe/test/test_dataframe.py
index 1fa63465d87a..d944c5bb5fa4 100644
--- a/python/ray/dataframe/test/test_dataframe.py
+++ b/python/ray/dataframe/test/test_dataframe.py
@@ -3053,10 +3053,43 @@ def test_var(ray_df, pandas_df):
 
 
 def test_where():
-    ray_df = create_test_dataframe()
+    pandas_df = pd.DataFrame(np.random.randn(100, 10),
+                             columns=list('abcdefghij'))
+    ray_df = rdf.DataFrame(pandas_df)
 
-    with pytest.raises(NotImplementedError):
-        ray_df.where(None)
+    pandas_cond_df = pandas_df % 5 < 2
+    ray_cond_df = ray_df % 5 < 2
+
+    pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
+    ray_result = ray_df.where(ray_cond_df, -ray_df)
+
+    assert ray_df_equals_pandas(ray_result, pandas_result)
+
+    other = pandas_df.loc[3]
+
+    pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
+    ray_result = ray_df.where(ray_cond_df, other, axis=1)
+
+    assert ray_df_equals_pandas(ray_result, pandas_result)
+
+    other = pandas_df['e']
+
+    pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
+    ray_result = ray_df.where(ray_cond_df, other, axis=0)
+
+    assert ray_df_equals_pandas(ray_result, pandas_result)
+
+    pandas_result = pandas_df.where(pandas_df < 2, True)
+    ray_result = ray_df.where(ray_df < 2, True)
+
+    assert ray_df_equals_pandas(ray_result, pandas_result)
+
+    other = -pandas_df.values
+
+    pandas_result = pandas_df.where(pandas_cond_df, other)
+    ray_result = ray_df.where(ray_cond_df, other)
+
+    assert ray_df_equals_pandas(ray_result, pandas_result)
 
 
 def test_xs():
diff --git a/python/ray/dataframe/utils.py b/python/ray/dataframe/utils.py
index 0ca927f571be..b9954e0e49f9 100644
--- a/python/ray/dataframe/utils.py
+++ b/python/ray/dataframe/utils.py
@@ -107,11 +107,7 @@ def to_pandas(df):
     Returns:
         A new pandas DataFrame.
     """
-    if df._row_partitions is not None:
-        pd_df = pd.concat(ray.get(df._row_partitions))
-    else:
-        pd_df = pd.concat(ray.get(df._col_partitions),
-                          axis=1)
+    pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
     pd_df.index = df.index
     pd_df.columns = df.columns
     return pd_df