diff --git a/python/ray/dataframe/dataframe.py b/python/ray/dataframe/dataframe.py index 37cb90bc2d24..bdcb54cf1ff2 100644 --- a/python/ray/dataframe/dataframe.py +++ b/python/ray/dataframe/dataframe.py @@ -43,8 +43,8 @@ _co_op_helper, _match_partitioning, _concat_index, - _correct_column_dtypes, - fix_blocks_dimensions) + fix_blocks_dimensions, + _compile_remote_dtypes) from . import get_npartitions from .index_metadata import _IndexMetadata from .iterator import PartitionIterator @@ -169,7 +169,7 @@ def __init__(self, data=None, index=None, columns=None, dtype=None, index=columns, axis=1) if self._dtypes_cache is None: - self._correct_dtypes() + self._get_remote_dtypes() def _get_frame_data(self): data = {} @@ -455,23 +455,11 @@ def ftypes(self): result.index = self.columns return result - def _correct_dtypes(self): - """Corrects dtypes by concatenating column blocks and then splitting them - apart back into the original blocks. - - Also caches ObjectIDs for the dtypes of every column. - - Args: - block_partitions: arglist of column blocks. + def _get_remote_dtypes(self): + """Finds and caches ObjectIDs for the dtypes of each column partition. """ - if self._block_partitions.shape[0] > 1: - self._block_partitions = np.array( - [_correct_column_dtypes._submit( - args=column, num_return_vals=len(column)) - for column in self._block_partitions.T]).T - - self._dtypes_cache = [_deploy_func.remote(lambda df: df.dtypes, pd_df) - for pd_df in self._block_partitions[0]] + self._dtypes_cache = [_compile_remote_dtypes.remote(*column) + for column in self._block_partitions.T] @property def dtypes(self): @@ -584,7 +572,7 @@ def _update_inplace(self, row_partitions=None, col_partitions=None, self._block_partitions[:, 0], index=index, axis=0) # Update dtypes - self._correct_dtypes() + self._get_remote_dtypes() def add_prefix(self, prefix): """Add a prefix to each of the column names. diff --git a/python/ray/dataframe/utils.py b/python/ray/dataframe/utils.py index 911ae911bce0..0da3e0ff8927 100644 --- a/python/ray/dataframe/utils.py +++ b/python/ray/dataframe/utils.py @@ -456,20 +456,15 @@ def _concat_index(*index_parts): return index_parts[0].append(index_parts[1:]) -@ray.remote -def _correct_column_dtypes(*column): - """Corrects dtypes of a column by concatenating column partitions and - splitting the column back into partitions. - - Args: - """ - concat_column = pd.concat(column, copy=False) - return create_blocks_helper(concat_column, len(column), 1) - - def fix_blocks_dimensions(blocks, axis): """Checks that blocks is 2D, and adds a dimension if not. """ if blocks.ndim < 2: return np.expand_dims(blocks, axis=axis ^ 1) return blocks + + +@ray.remote +def _compile_remote_dtypes(*column_of_blocks): + small_dfs = [df.loc[0:0] for df in column_of_blocks] + return pd.concat(small_dfs, copy=False).dtypes