Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 115 additions & 3 deletions python/ray/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4464,9 +4464,121 @@ def remote_func(df):

def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
errors='raise', try_cast=False, raise_on_error=None):
raise NotImplementedError(
"To contribute to Pandas on Ray, please visit "
"github.com/ray-project/ray.")
"""Replaces values not meeting condition with values in other.

Args:
cond: A condition to be met, can be callable, array-like or a
DataFrame.
other: A value or DataFrame of values to use for setting this.
inplace: Whether or not to operate inplace.
axis: The axis to apply over. Only valid when a Series is passed
as other.
level: The MultiLevel index level to apply over.
errors: Whether or not to raise errors. Does nothing in Pandas.
try_cast: Try to cast the result back to the input type.
raise_on_error: Whether to raise invalid datatypes (deprecated).

Returns:
A new DataFrame with the replaced values.
"""

inplace = validate_bool_kwarg(inplace, 'inplace')

if isinstance(other, pd.Series) and axis is None:
raise ValueError("Must specify axis=0 or 1")

if level is not None:
raise NotImplementedError("Multilevel Index not yet supported on "
"Pandas on Ray.")

axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0

cond = cond(self) if callable(cond) else cond

if not isinstance(cond, DataFrame):
if not hasattr(cond, 'shape'):
cond = np.asanyarray(cond)
if cond.shape != self.shape:
raise ValueError("Array conditional must be same shape as "
"self")
cond = DataFrame(cond, index=self.index, columns=self.columns)

zipped_partitions = self._copartition(cond, self.index)
args = (False, axis, level, errors, try_cast, raise_on_error)

@ray.remote
def where_helper(left, cond, other, *args):

left = pd.concat(ray.get(left.tolist()), axis=1)
# We have to reset the index and columns here because we are coming
# from blocks and the axes are set according to the blocks. We have
# already correctly copartitioned everything, so there's no
# correctness problems with doing this.
left.reset_index(inplace=True, drop=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since everything is concatenated into row partitions, can you only reset the column index?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to reset the index here because that's what other is relying on.

left.columns = pd.RangeIndex(len(left.columns))

cond = pd.concat(ray.get(cond.tolist()), axis=1)
cond.reset_index(inplace=True, drop=True)
cond.columns = pd.RangeIndex(len(cond.columns))

if isinstance(other, np.ndarray):
other = pd.concat(ray.get(other.tolist()), axis=1)
other.reset_index(inplace=True, drop=True)
other.columns = pd.RangeIndex(len(other.columns))

return left.where(cond, other, *args)

if isinstance(other, DataFrame):
other_zipped = (v for k, v in self._copartition(other,
self.index))

new_partitions = [where_helper.remote(k, v, next(other_zipped),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can k, v be converted to lists and passed in by reference? Ray will automatically deserialize then.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not without merging them together, then also passing the length of the left. Performance-wise it's not much different.

*args)
for k, v in zipped_partitions]

# Series has to be treated specially because we're operating on row
# partitions from here on.
elif isinstance(other, pd.Series):
if axis == 0:
# Pandas determines which index to use based on axis.
other = other.reindex(self.index)
other.index = pd.RangeIndex(len(other))

# Since we're working on row partitions, we have to partition
# the Series based on the partitioning of self (since both
# self and cond are co-partitioned by self.
other_builder = []
for length in self._row_metadata._lengths:
other_builder.append(other[:length])
other = other[length:]
# Resetting the index here ensures that we apply each part
# to the correct row within the partitions.
other.index = pd.RangeIndex(len(other))

other = (obj for obj in other_builder)

new_partitions = [where_helper.remote(k, v, next(other,
pd.Series()),
*args)
for k, v in zipped_partitions]
else:
other = other.reindex(self.columns)
other.index = pd.RangeIndex(len(other))
new_partitions = [where_helper.remote(k, v, other, *args)
for k, v in zipped_partitions]

else:
new_partitions = [where_helper.remote(k, v, other, *args)
for k, v in zipped_partitions]

if inplace:
self._update_inplace(row_partitions=new_partitions,
row_metadata=self._row_metadata,
col_metadata=self._col_metadata)
else:
return DataFrame(row_partitions=new_partitions,
row_metadata=self._row_metadata,
col_metadata=self._col_metadata)

def xs(self, key, axis=0, level=None, drop_level=True):
raise NotImplementedError(
Expand Down
32 changes: 29 additions & 3 deletions python/ray/dataframe/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3053,10 +3053,36 @@ def test_var(ray_df, pandas_df):


def test_where():
ray_df = create_test_dataframe()
pandas_df = pd.DataFrame(np.random.randn(100, 10),
columns=list('abcdefghij'))
ray_df = rdf.DataFrame(pandas_df)

with pytest.raises(NotImplementedError):
ray_df.where(None)
pandas_cond_df = pandas_df % 5 < 2
ray_cond_df = ray_df % 5 < 2

pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
ray_result = ray_df.where(ray_cond_df, -ray_df)

assert ray_df_equals_pandas(ray_result, pandas_result)

other = pandas_df.loc[3]

pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
ray_result = ray_df.where(ray_cond_df, other, axis=1)

assert ray_df_equals_pandas(ray_result, pandas_result)

other = pandas_df['e']

pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
ray_result = ray_df.where(ray_cond_df, other, axis=0)

assert ray_df_equals_pandas(ray_result, pandas_result)

pandas_result = pandas_df.where(pandas_df < 2, True)
ray_result = ray_df.where(ray_df < 2, True)

assert ray_df_equals_pandas(ray_result, pandas_result)


def test_xs():
Expand Down
5 changes: 3 additions & 2 deletions python/ray/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def to_pandas(df):
A new pandas DataFrame.
"""
if df._row_partitions is not None:
pd_df = pd.concat(ray.get(df._row_partitions))
print("Yes")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this

pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
else:
pd_df = pd.concat(ray.get(df._col_partitions),
axis=1)
axis=1, copy=False)
pd_df.index = df.index
pd_df.columns = df.columns
return pd_df
Expand Down