diff --git a/python/ray/dataframe/dataframe.py b/python/ray/dataframe/dataframe.py
index 9fbc73aa5b0e..2d749f20d3ab 100644
--- a/python/ray/dataframe/dataframe.py
+++ b/python/ray/dataframe/dataframe.py
@@ -40,7 +40,8 @@
     _reindex_helper,
     _co_op_helper,
     _match_partitioning,
-    _concat_index)
+    _concat_index,
+    _correct_column_dtypes)
 from . import get_npartitions
 from .index_metadata import _IndexMetadata
 
@@ -50,7 +51,8 @@ class DataFrame(object):
 
     def __init__(self, data=None, index=None, columns=None, dtype=None,
                  copy=False, col_partitions=None, row_partitions=None,
-                 block_partitions=None, row_metadata=None, col_metadata=None):
+                 block_partitions=None, row_metadata=None, col_metadata=None,
+                 dtypes_cache=None):
         """Distributed DataFrame object backed by Pandas dataframes.
 
         Args:
@@ -74,6 +76,7 @@ def __init__(self, data=None, index=None, columns=None, dtype=None,
             col_metadata (_IndexMetadata): Metadata for the new dataframe's
                 columns
         """
+        self._dtypes_cache = dtypes_cache
 
         # Check type of data and use appropriate constructor
         if data is not None or (col_partitions is None and
@@ -83,6 +86,9 @@ def __init__(self, data=None, index=None, columns=None, dtype=None,
             pd_df = pd.DataFrame(data=data, index=index, columns=columns,
                                  dtype=dtype, copy=copy)
 
+            # Cache dtypes
+            self._dtypes_cache = pd_df.dtypes
+
             # TODO convert _partition_pandas_dataframe to block partitioning.
             row_partitions = \
                 _partition_pandas_dataframe(pd_df,
@@ -117,6 +123,11 @@ def __init__(self, data=None, index=None, columns=None, dtype=None,
                 axis = 1
                 partitions = col_partitions
                 axis_length = None
+                # Column partitions already have the correct dtypes.
+                self._dtypes_cache = [
+                    _deploy_func.remote(lambda df: df.dtypes, pd_df)
+                    for pd_df in col_partitions
+                ]
 
                 # TODO: write explicit tests for "short and wide"
                 # column partitions
@@ -151,6 +162,9 @@ def __init__(self, data=None, index=None, columns=None, dtype=None,
         self._col_metadata = _IndexMetadata(self._block_partitions[0, :],
                                             index=columns, axis=1)
 
+        if self._dtypes_cache is None:
+            self._correct_dtypes()
+
     def _get_row_partitions(self):
         return [_blocks_to_row.remote(*part)
                 for part in self._block_partitions]
@@ -414,6 +428,24 @@ def ftypes(self):
         result.index = self.columns
         return result
 
+    def _correct_dtypes(self):
+        """Corrects dtypes by concatenating the blocks of each column and
+        then splitting them back into the original blocks.
+
+        Also caches ObjectIDs for the dtypes of every column.
+
+        This method takes no arguments; it updates self._block_partitions
+        and self._dtypes_cache in place.
+        """
+        if self._block_partitions.shape[0] > 1:
+            self._block_partitions = np.array(
+                [_correct_column_dtypes._submit(
+                    args=column, num_return_vals=len(column))
+                 for column in self._block_partitions.T]).T
+
+        self._dtypes_cache = [_deploy_func.remote(lambda df: df.dtypes, pd_df)
+                              for pd_df in self._block_partitions[0]]
+
     @property
     def dtypes(self):
         """Get the dtypes for this DataFrame.
@@ -421,12 +453,15 @@ def dtypes(self):
         Returns:
             The dtypes for this DataFrame.
         """
-        # The dtypes are common across all partitions.
-        # The first partition will be enough.
-        result = ray.get(_deploy_func.remote(lambda df: df.dtypes,
-                                             self._row_partitions[0]))
-        result.index = self.columns
-        return result
+        assert self._dtypes_cache is not None
+
+        if isinstance(self._dtypes_cache, list) and \
+                isinstance(self._dtypes_cache[0],
+                           ray.local_scheduler.ObjectID):
+            self._dtypes_cache = pd.concat(ray.get(self._dtypes_cache))
+            self._dtypes_cache.index = self.columns
+
+        return self._dtypes_cache
 
     @property
     def empty(self):
@@ -500,6 +535,7 @@ def _update_inplace(self, row_partitions=None, col_partitions=None,
 
         if block_partitions is not None:
             self._block_partitions = block_partitions
+
         elif row_partitions is not None:
             self._row_partitions = row_partitions
 
@@ -520,6 +556,9 @@ def _update_inplace(self, row_partitions=None, col_partitions=None,
         self._row_metadata = _IndexMetadata(
             self._block_partitions[:, 0], index=index, axis=0)
 
+        # Recompute the dtypes cache, since the partitions changed.
+        self._correct_dtypes()
+
     def add_prefix(self, prefix):
         """Add a prefix to each of the column names.
 
@@ -570,7 +609,8 @@ def copy(self, deep=True):
         """
         return DataFrame(block_partitions=self._block_partitions,
                          columns=self.columns,
-                         index=self.index)
+                         index=self.index,
+                         dtypes_cache=self.dtypes)
 
     def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
                 group_keys=True, squeeze=False, **kwargs):
diff --git a/python/ray/dataframe/io.py b/python/ray/dataframe/io.py
index c1abc0ec474c..cf91dbe5d647 100644
--- a/python/ray/dataframe/io.py
+++ b/python/ray/dataframe/io.py
@@ -261,7 +261,6 @@ def read_csv(filepath,
         df = _read_csv_with_offset.remote(
             filepath, start, end, kwargs=kwargs)
         df_obj_ids.append(df)
-
     return DataFrame(row_partitions=df_obj_ids, columns=columns)
 
 
diff --git a/python/ray/dataframe/utils.py b/python/ray/dataframe/utils.py
index 78d728f69023..6b56db87cb84 100644
--- a/python/ray/dataframe/utils.py
+++ b/python/ray/dataframe/utils.py
@@ -379,3 +379,15 @@ def _match_partitioning(column_partition, lengths, index):
 @ray.remote
 def _concat_index(*index_parts):
     return index_parts[0].append(index_parts[1:])
+
+
+@ray.remote
+def _correct_column_dtypes(*column):
+    """Corrects the dtypes of a column by concatenating its partitions and
+    then splitting the column back into the original partitions.
+
+    Args:
+        column: arglist of the blocks that make up a single column.
+    """
+    concat_column = pd.concat(column, copy=False)
+    return create_blocks_helper(concat_column, len(column), 1)
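
Reviewer note: a minimal, pandas-only sketch (no Ray; every name in it is
hypothetical) of the unification trick _correct_column_dtypes relies on.
Blocks of the same column that were built independently can disagree on
dtype; pd.concat resolves them to one common dtype, and splitting the result
back apart yields blocks that all agree, so the dtypes of any single block
row are representative of the whole frame:

    import pandas as pd

    # Two blocks of one column, built independently: the first is inferred
    # as int64, the second as float64 (the None forces a float).
    block_a = pd.DataFrame({"x": [1, 2, 3]})
    block_b = pd.DataFrame({"x": [4.0, None]})
    assert block_a["x"].dtype != block_b["x"].dtype

    # Concatenating lets pandas unify the column under a single dtype ...
    combined = pd.concat([block_a, block_b], copy=False)
    assert combined["x"].dtype == "float64"

    # ... and re-splitting produces blocks that now agree, which is why it
    # is safe to cache the dtypes of self._block_partitions[0] afterwards.
    lengths = [len(block_a), len(block_b)]
    starts = [0, lengths[0]]
    new_blocks = [combined.iloc[s:s + n] for s, n in zip(starts, lengths)]
    assert all(b["x"].dtype == "float64" for b in new_blocks)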
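
A second sketch, of the lazy materialization the dtypes property now
performs: the cache starts out as one future per column block and is
collapsed into a concrete pd.Series on first access, so every later access
is a plain attribute read. The class below is hypothetical and holds plain
callables where the real code holds Ray ObjectIDs and calls ray.get:

    import pandas as pd

    class LazyDtypes(object):
        """Stand-in for the _dtypes_cache protocol: a list of per-block
        futures until first access, a materialized Series afterwards."""

        def __init__(self, futures, columns):
            self._cache = futures    # one "future" (a callable) per block
            self._columns = columns

        @property
        def dtypes(self):
            if isinstance(self._cache, list):
                # The real code ray.get()s a list of ObjectIDs here.
                self._cache = pd.concat([f() for f in self._cache])
                self._cache.index = self._columns
            return self._cache

    blocks = [pd.DataFrame({"a": [1]}), pd.DataFrame({"b": [1.0]})]
    lazy = LazyDtypes([(lambda b=b: b.dtypes) for b in blocks], ["a", "b"])
    print(lazy.dtypes)  # a: int64, b: float64 -- cached from now on

Note, too, that copy() hands the new frame dtypes_cache=self.dtypes rather
than the raw cache, so copying materializes the cache once; the copy then
starts from a concrete Series instead of sharing pending ObjectIDs.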