diff --git a/python/ray/dataframe/dataframe.py b/python/ray/dataframe/dataframe.py index 3253cd0a172b..15bb686d6f3c 100644 --- a/python/ray/dataframe/dataframe.py +++ b/python/ray/dataframe/dataframe.py @@ -3,6 +3,7 @@ from __future__ import print_function import pandas as pd +import functools from pandas.api.types import is_scalar from pandas.util._validators import validate_bool_kwarg from pandas.core.index import _ensure_index_from_sequences @@ -15,7 +16,8 @@ is_bool_dtype, is_list_like, is_numeric_dtype, - is_timedelta64_dtype) + is_timedelta64_dtype, + _get_dtype_from_object) from pandas.core.indexing import check_bool_indexer import warnings @@ -977,9 +979,42 @@ def assign(self, **kwargs): "github.com/ray-project/ray.") def astype(self, dtype, copy=True, errors='raise', **kwargs): - raise NotImplementedError( - "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + if isinstance(dtype, dict): + if (not set(dtype.keys()).issubset(set(self.columns)) and + errors == 'raise'): + raise KeyError( + "Only a column name can be used for the key in" + "a dtype mappings argument.") + columns = list(dtype.keys()) + col_idx = [(self.columns.get_loc(columns[i]), columns[i]) + if columns[i] in self.columns + else (columns[i], columns[i]) + for i in range(len(columns))] + new_dict = {} + for idx, key in col_idx: + new_dict[idx] = dtype[key] + new_rows = _map_partitions(lambda df, dt: df.astype(dtype=dt, + copy=True, + errors=errors, + **kwargs), + self._row_partitions, new_dict) + if copy: + return DataFrame(row_partitions=new_rows, + columns=self.columns, + index=self.index) + self._row_partitions = new_rows + else: + new_blocks = [_map_partitions(lambda d: d.astype(dtype=dtype, + copy=True, + errors=errors, + **kwargs), + block) + for block in self._block_partitions] + if copy: + return DataFrame(block_partitions=new_blocks, + columns=self.columns, + index=self.index) + self._block_partitions = new_blocks def at_time(self, time, asof=False): raise NotImplementedError( @@ -2688,9 +2723,42 @@ def select(self, crit, axis=0): "github.com/ray-project/ray.") def select_dtypes(self, include=None, exclude=None): - raise NotImplementedError( - "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + # Validates arguments for whether both include and exclude are None or + # if they are disjoint. Also invalidates string dtypes. + pd.DataFrame().select_dtypes(include, exclude) + + if include and not is_list_like(include): + include = [include] + elif not include: + include = [] + + if exclude and not is_list_like(exclude): + exclude = [exclude] + elif not exclude: + exclude = [] + + sel = tuple(map(set, (include, exclude))) + + include, exclude = map( + lambda x: set(map(_get_dtype_from_object, x)), sel) + + include_these = pd.Series(not bool(include), index=self.columns) + exclude_these = pd.Series(not bool(exclude), index=self.columns) + + def is_dtype_instance_mapper(column, dtype): + return column, functools.partial(issubclass, dtype.type) + + for column, f in itertools.starmap(is_dtype_instance_mapper, + self.dtypes.iteritems()): + if include: # checks for the case of empty include or exclude + include_these[column] = any(map(f, include)) + if exclude: + exclude_these[column] = not any(map(f, exclude)) + + dtype_indexer = include_these & exclude_these + indicate = [i for i in range(len(dtype_indexer.values)) + if not dtype_indexer.values[i]] + return self.drop(columns=self.columns[indicate], inplace=False) def sem(self, axis=None, skipna=None, level=None, ddof=1, numeric_only=None, **kwargs): @@ -3317,9 +3385,16 @@ def __getattr__(self, key): raise e def __setitem__(self, key, value): - raise NotImplementedError( - "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + if not isinstance(key, str): + raise NotImplementedError( + "To contribute to Pandas on Ray, please visit " + "github.com/ray-project/ray.") + if key not in self.columns: + self.insert(loc=len(self.columns), column=key, value=value) + else: + loc = self.columns.get_loc(key) + self.__delitem__(key) + self.insert(loc=loc, column=key, value=value) def __len__(self): """Gets the length of the dataframe. diff --git a/python/ray/dataframe/test/test_dataframe.py b/python/ray/dataframe/test/test_dataframe.py index dc3b83d35e15..42801dc315cd 100644 --- a/python/ray/dataframe/test/test_dataframe.py +++ b/python/ray/dataframe/test/test_dataframe.py @@ -905,10 +905,28 @@ def test_assign(): def test_astype(): - ray_df = create_test_dataframe() + td = TestData() + ray_df_frame = from_pandas(td.frame, 2) + our_df_casted = ray_df_frame.astype(np.int32) + expected_df_casted = pd.DataFrame(td.frame.values.astype(np.int32), + index=td.frame.index, + columns=td.frame.columns) - with pytest.raises(NotImplementedError): - ray_df.astype(None) + assert(ray_df_equals_pandas(our_df_casted, expected_df_casted)) + + our_df_casted = ray_df_frame.astype(np.float64) + expected_df_casted = pd.DataFrame(td.frame.values.astype(np.float64), + index=td.frame.index, + columns=td.frame.columns) + + assert(ray_df_equals_pandas(our_df_casted, expected_df_casted)) + + our_df_casted = ray_df_frame.astype(str) + expected_df_casted = pd.DataFrame(td.frame.values.astype(str), + index=td.frame.index, + columns=td.frame.columns) + + assert(ray_df_equals_pandas(our_df_casted, expected_df_casted)) def test_at_time(): @@ -2524,10 +2542,25 @@ def test_select(): def test_select_dtypes(): - ray_df = create_test_dataframe() - - with pytest.raises(NotImplementedError): - ray_df.select_dtypes() + df = pd.DataFrame({'test1': list('abc'), + 'test2': np.arange(3, 6).astype('u1'), + 'test3': np.arange(8.0, 11.0, dtype='float64'), + 'test4': [True, False, True], + 'test5': pd.date_range('now', periods=3).values, + 'test6': list(range(5, 8))}) + include = np.float, 'integer' + exclude = np.bool_, + rd = from_pandas(df, 2) + r = rd.select_dtypes(include=include, exclude=exclude) + + e = df[["test2", "test3", "test6"]] + assert(ray_df_equals_pandas(r, e)) + + try: + rdf.DataFrame().select_dtypes() + assert(False) + except ValueError: + assert(True) def test_sem():