Allow for 0 length tables
Jay Chia committed Jul 16, 2024
1 parent 683af8d commit ca55235
Showing 12 changed files with 106 additions and 38 deletions.
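
The change threaded through all twelve files is that `Table::new` and `Table::new_unchecked` now take an explicit `num_rows` argument, and `Table::len()` returns that stored count instead of reading the length of the first column. Below is a minimal sketch of the new constructor contract, using a hypothetical `TableSketch` that holds only column lengths rather than real `Series`/`Schema` values: every column must either be a broadcastable length-1 column or match the declared `num_rows`, and zero-column and zero-row tables are both representable.

```rust
// Hypothetical stand-in for daft-table's `Table` (column lengths only, no real
// `Series`/`Schema`), sketching the new constructor contract: the caller supplies
// `num_rows`, every column must either broadcast (length 1) or match it, and
// `len()` reports the stored count even when there are zero columns.
#[derive(Debug)]
struct TableSketch {
    column_lengths: Vec<usize>,
    num_rows: usize,
}

impl TableSketch {
    fn new(column_lengths: Vec<usize>, num_rows: usize) -> Result<Self, String> {
        for (i, len) in column_lengths.iter().enumerate() {
            if *len != 1 && *len != num_rows {
                return Err(format!(
                    "column {i} has length {len}, expected 1 or {num_rows}"
                ));
            }
        }
        Ok(TableSketch { column_lengths, num_rows })
    }

    fn num_columns(&self) -> usize {
        self.column_lengths.len()
    }

    fn len(&self) -> usize {
        self.num_rows
    }
}

fn main() {
    // A zero-column table no longer has to report length 0.
    let no_columns = TableSketch::new(vec![], 3).unwrap();
    assert_eq!((no_columns.num_columns(), no_columns.len()), (0, 3));
    // A genuinely empty (0-row) table is representable too.
    assert_eq!(TableSketch::new(vec![0, 0], 0).unwrap().len(), 0);
    // Length-1 columns are still treated as broadcastable.
    assert_eq!(TableSketch::new(vec![5, 1, 5], 5).unwrap().len(), 5);
    // Mismatched lengths are rejected, mirroring the real `Table::new`.
    assert!(TableSketch::new(vec![5, 4], 5).is_err());
}
```

The real constructor in `src/daft-table/src/lib.rs` additionally checks that the schema's field count and field metadata line up with the columns, as shown in the hunks below.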
9 changes: 7 additions & 2 deletions src/daft-csv/src/read.rs
@@ -168,7 +168,11 @@ fn tables_concat(mut tables: Vec<Table>) -> DaftResult<Table> {
Series::concat(series_to_cat.as_slice())
})
.collect::<DaftResult<Vec<_>>>()?;
Table::new(first_table.schema.clone(), new_series)
Table::new(
first_table.schema.clone(),
new_series,
tables.iter().map(|t| t.len()).sum(),
)
}

async fn read_csv_single_into_table(
@@ -508,7 +512,8 @@ fn parse_into_column_array_chunk_stream(
)
})
.collect::<DaftResult<Vec<Series>>>()?;
Ok(Table::new_unchecked(read_schema, chunk))
let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0);
Ok(Table::new_unchecked(read_schema, chunk, num_rows))
})();
let _ = send.send(result);
});
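
The call sites in this file derive the new row-count argument in two ways: `tables_concat` sums the lengths of the input tables, while the chunk parser takes the first column's length and falls back to 0 when a chunk has no columns. A small sketch of both derivations, with plain integer lengths as hypothetical stand-ins for `Table` and `Series` values:

```rust
// Plain-integer sketch of the two row-count derivations used above; the vectors
// of lengths are hypothetical stand-ins for `Table` and `Series` values.
fn main() {
    // `tables_concat`: the concatenated table's row count is the sum of its parts,
    // so zero-length inputs are harmless.
    let table_lens = [3_usize, 0, 5];
    let concat_num_rows: usize = table_lens.iter().sum();
    assert_eq!(concat_num_rows, 8);

    // `parse_into_column_array_chunk_stream`: a chunk's row count comes from its
    // first column, falling back to 0 when the chunk has no columns at all.
    let chunk_column_lens: Vec<usize> = vec![];
    let chunk_num_rows = chunk_column_lens.first().copied().unwrap_or(0);
    assert_eq!(chunk_num_rows, 0);
}
```

The parquet reader further below applies the same summing idea, adding up `rr.num_rows` across its row ranges.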
1 change: 1 addition & 0 deletions src/daft-execution/src/task/mod.rs
@@ -243,6 +243,7 @@ mod tests {
.boxed(),
)
.unwrap()],
input_meta.num_rows.unwrap(),
)
.unwrap()]),
None,
1 change: 1 addition & 0 deletions src/daft-execution/src/test/mod.rs
@@ -41,6 +41,7 @@ pub(crate) fn mock_micropartition(num_rows: usize) -> MicroPartition {
.boxed(),
)
.unwrap()],
num_rows,
)
.unwrap()]),
None,
5 changes: 4 additions & 1 deletion src/daft-json/src/local.rs
@@ -193,6 +193,7 @@ impl<'a> JsonReader<'a> {
})
.collect::<IndexMap<_, _>>();

let mut num_rows = 0;
for record in iter {
let value = record.map_err(|e| super::Error::JsonDeserializationError {
string: e.to_string(),
@@ -221,6 +222,8 @@ impl<'a> JsonReader<'a> {
.into());
}
}

num_rows += 1;
}
let columns = columns
.into_values()
@@ -234,7 +237,7 @@ impl<'a> JsonReader<'a> {
})
.collect::<DaftResult<Vec<_>>>()?;

let tbl = Table::new_unchecked(self.schema.clone(), columns);
let tbl = Table::new_unchecked(self.schema.clone(), columns, num_rows);

if let Some(pred) = &self.predicate {
tbl.filter(&[pred.clone()])
13 changes: 11 additions & 2 deletions src/daft-json/src/read.rs
@@ -163,7 +163,11 @@ pub(crate) fn tables_concat(mut tables: Vec<Table>) -> DaftResult<Table> {
Series::concat(series_to_cat.as_slice())
})
.collect::<DaftResult<Vec<_>>>()?;
Table::new(first_table.schema.clone(), new_series)
Table::new(
first_table.schema.clone(),
new_series,
tables.iter().map(|t| t.len()).sum(),
)
}

async fn read_json_single_into_table(
@@ -424,6 +428,7 @@ fn parse_into_column_array_chunk_stream(
let schema = schema.clone();
let daft_schema = daft_schema.clone();
let daft_fields = daft_fields.clone();
let num_rows = records.len();
tokio::spawn(async move {
let (send, recv) = tokio::sync::oneshot::channel();
rayon::spawn(move || {
@@ -450,7 +455,11 @@
)
})
.collect::<DaftResult<Vec<_>>>()?;
Ok(Table::new_unchecked(daft_schema.clone(), all_series))
Ok(Table::new_unchecked(
daft_schema.clone(),
all_series,
num_rows,
))
})();
let _ = send.send(result);
});
6 changes: 5 additions & 1 deletion src/daft-parquet/src/file.rs
@@ -549,7 +549,11 @@ impl ParquetFileReader {
.collect::<DaftResult<Vec<_>>>()?;
let daft_schema = daft_core::schema::Schema::try_from(self.arrow_schema.as_ref())?;

Table::new(daft_schema, all_series)
Table::new(
daft_schema,
all_series,
self.row_ranges.as_ref().iter().map(|rr| rr.num_rows).sum(),
)
}

pub async fn read_from_ranges_into_arrow_arrays(
2 changes: 2 additions & 0 deletions src/daft-plan/src/source_info/file_info.rs
@@ -173,9 +173,11 @@ impl FileInfos {
arrow2::array::PrimitiveArray::<i64>::from(&self.num_rows).to_boxed(),
))?,
];
let num_rows = columns.first().map(|s| s.len()).unwrap();
Table::new(
Schema::new(columns.iter().map(|s| s.field().clone()).collect())?,
columns,
num_rows,
)
}
}
89 changes: 63 additions & 26 deletions src/daft-table/src/lib.rs
@@ -32,26 +32,26 @@ pub use python::register_modules;
pub struct Table {
pub schema: SchemaRef,
columns: Vec<Series>,
num_rows: usize,
}

impl Table {
pub fn new<S: Into<SchemaRef>>(schema: S, columns: Vec<Series>) -> DaftResult<Self> {
pub fn new<S: Into<SchemaRef>>(
schema: S,
columns: Vec<Series>,
num_rows: usize,
) -> DaftResult<Self> {
let schema: SchemaRef = schema.into();
if schema.fields.len() != columns.len() {
return Err(DaftError::SchemaMismatch(format!("While building a Table, we found that the number of fields did not match between the schema and the input columns.\n {:?}\n vs\n {:?}", schema.fields.len(), columns.len())));
}
let mut num_rows = 1;

for (field, series) in schema.fields.values().zip(columns.iter()) {
if field != series.field() {
return Err(DaftError::SchemaMismatch(format!("While building a Table, we found that the Schema Field and the Series Field did not match. schema field: {field} vs series field: {}", series.field())));
}
if (series.len() != 1) && (series.len() != num_rows) {
if num_rows == 1 {
num_rows = series.len();
} else {
return Err(DaftError::ValueError(format!("While building a Table, we found that the Series lengths did not match. Series named: {} had length: {} vs rest of the DataFrame had length: {}", field.name, series.len(), num_rows)));
}
return Err(DaftError::ValueError(format!("While building a Table, we found that the Series lengths did not match. Series named: {} had length: {} vs rest of the Table had length: {}", field.name, series.len(), num_rows)));
}
}

@@ -69,13 +69,19 @@ impl Table {
Ok(Table {
schema,
columns: columns?,
num_rows,
})
}

pub fn new_unchecked<S: Into<SchemaRef>>(schema: S, columns: Vec<Series>) -> Self {
pub fn new_unchecked<S: Into<SchemaRef>>(
schema: S,
columns: Vec<Series>,
num_rows: usize,
) -> Self {
Table {
schema: schema.into(),
columns,
num_rows,
}
}

@@ -87,16 +93,36 @@ impl Table {
let series = Series::empty(field_name, &field.dtype);
columns.push(series)
}
Ok(Table { schema, columns })
Ok(Table {
schema,
columns,
num_rows: 0,
})
}
None => Self::new(Schema::empty(), vec![]),
None => Self::new(Schema::empty(), vec![], 0),
}
}

pub fn from_columns(columns: Vec<Series>) -> DaftResult<Self> {
let fields = columns.iter().map(|s| s.field().clone()).collect();
let schema = Schema::new(fields)?;
Table::new(schema, columns)

let num_rows = if columns.is_empty() {
// Size of the table is 0 if no columns provided
0
} else if columns.iter().all(|s| s.len() == 1) {
// Size of the table is 1 if all columns have length 1
1
} else {
// Size of the table is the first non-len-1 column
columns
.iter()
.filter_map(|s| if s.len() != 1 { Some(s.len()) } else { None })
.next()
.unwrap()
};

Table::new(schema, columns, num_rows)
}

pub fn num_columns(&self) -> usize {
@@ -108,11 +134,7 @@ impl Table {
}

pub fn len(&self) -> usize {
if self.num_columns() == 0 {
0
} else {
self.get_column_by_index(0).unwrap().len()
}
self.num_rows
}

pub fn is_empty(&self) -> bool {
@@ -122,17 +144,16 @@ impl Table {
pub fn slice(&self, start: usize, end: usize) -> DaftResult<Self> {
let new_series: DaftResult<Vec<_>> =
self.columns.iter().map(|s| s.slice(start, end)).collect();
Ok(Table {
schema: self.schema.clone(),
columns: new_series?,
})
let new_num_rows = self.len().min(end - start);
Table::new(self.schema.clone(), new_series?, new_num_rows)
}

pub fn head(&self, num: usize) -> DaftResult<Self> {
if num >= self.len() {
return Ok(Table {
schema: self.schema.clone(),
columns: self.columns.clone(),
num_rows: self.len(),
});
}
self.slice(0, num)
@@ -253,12 +274,13 @@ impl Table {
Ok(Table {
schema: self.schema.clone(),
columns: new_series?,
num_rows: mask.len() - mask.null_count(),
})
}

pub fn take(&self, idx: &Series) -> DaftResult<Self> {
let new_series: DaftResult<Vec<_>> = self.columns.iter().map(|s| s.take(idx)).collect();
Ok(Table::new(self.schema.clone(), new_series?).unwrap())
Ok(Table::new(self.schema.clone(), new_series?, idx.len()).unwrap())
}

pub fn concat<T: AsRef<Table>>(tables: &[T]) -> DaftResult<Self> {
@@ -290,9 +312,11 @@ impl Table {
.collect();
new_series.push(Series::concat(series_to_cat.as_slice())?);
}

Ok(Table {
schema: first_table.schema.clone(),
columns: new_series,
num_rows: tables.iter().map(|t| t.as_ref().len()).sum(),
})
}

@@ -459,12 +483,25 @@ impl Table {
}
seen.insert(name.clone());
}
let schema = Schema::new(fields)?;
Table::new(schema, result_series)
let new_schema = Schema::new(fields)?;

Table::new(
new_schema,
result_series,
// TODO: Assuming that eval_expression_list doesn't change cardinality might be too strong of an assumption
// In that case we need to figure out a better way of deriving the length of this new Table instead of `self.len()`
self.len(),
)
}

pub fn as_physical(&self) -> DaftResult<Self> {
let new_series: DaftResult<Vec<_>> = self.columns.iter().map(|s| s.as_physical()).collect();
Table::from_columns(new_series?)
let new_series: Vec<Series> = self
.columns
.iter()
.map(|s| s.as_physical())
.collect::<DaftResult<Vec<_>>>()?;
let new_schema = Schema::new(new_series.iter().map(|s| s.field().clone()).collect())?;
Table::new(new_schema, new_series, self.len())
}

pub fn cast_to_schema(&self, schema: &Schema) -> DaftResult<Self> {
@@ -626,7 +663,7 @@ mod test {
a.field().clone().rename("a"),
b.field().clone().rename("b"),
])?;
let table = Table::new(schema, vec![a, b])?;
let table = Table::new(schema, vec![a, b], 3)?;
let e1 = col("a").add(col("b"));
let result = table.eval_expression(&e1)?;
assert_eq!(*result.data_type(), DataType::Float64);
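
`Table::from_columns` (in the `lib.rs` hunks above) now has to infer the row count itself: no columns means zero rows, all length-1 columns mean one row, and otherwise the first non-length-1 column decides. A standalone sketch of that inference, assuming a hypothetical `infer_num_rows` helper over plain column lengths rather than real `Series` values:

```rust
// Standalone sketch of the row-count inference now done by `Table::from_columns`;
// `infer_num_rows` is a hypothetical helper over plain column lengths rather than
// real `Series` values.
fn infer_num_rows(column_lens: &[usize]) -> usize {
    if column_lens.is_empty() {
        // No columns provided: the table is empty.
        0
    } else if column_lens.iter().all(|&len| len == 1) {
        // Every column is a broadcastable length-1 column: one row.
        1
    } else {
        // Otherwise the first column that is not length 1 decides.
        column_lens.iter().copied().find(|&len| len != 1).unwrap()
    }
}

fn main() {
    assert_eq!(infer_num_rows(&[]), 0);
    assert_eq!(infer_num_rows(&[1, 1, 1]), 1);
    assert_eq!(infer_num_rows(&[1, 4, 4]), 4);
    // Columns that are all empty infer a zero-row table.
    assert_eq!(infer_num_rows(&[0, 0]), 0);
}
```

In the real code the inferred count is then passed straight to `Table::new`, which re-validates it against the columns.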
9 changes: 6 additions & 3 deletions src/daft-table/src/ops/joins/hash_join.rs
@@ -96,10 +96,11 @@ pub(super) fn hash_inner_join(
drop(lkeys);
drop(rkeys);

let num_rows = lidx.len();
join_series =
add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;

Table::new(join_schema, join_series)
Table::new(join_schema, join_series, num_rows)
}

pub(super) fn hash_left_right_join(
@@ -218,10 +219,11 @@ pub(super) fn hash_left_right_join(
drop(lkeys);
drop(rkeys);

let num_rows = lidx.len();
join_series =
add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;

Table::new(join_schema, join_series)
Table::new(join_schema, join_series, num_rows)
}

pub(super) fn hash_semi_anti_join(
@@ -449,8 +451,9 @@ pub(super) fn hash_outer_join(
drop(lkeys);
drop(rkeys);

let num_rows = lidx.len();
join_series =
add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;

Table::new(join_schema, join_series)
Table::new(join_schema, join_series, num_rows)
}
3 changes: 2 additions & 1 deletion src/daft-table/src/ops/joins/mod.rs
@@ -295,9 +295,10 @@ impl Table {
drop(ltable);
drop(rtable);

let num_rows = lidx.len();
join_series =
add_non_join_key_columns(self, right, lidx, ridx, left_on, right_on, join_series)?;

Table::new(join_schema, join_series)
Table::new(join_schema, join_series, num_rows)
}
}
2 changes: 1 addition & 1 deletion src/daft-table/src/ops/unpivot.rs
@@ -59,6 +59,6 @@ impl Table {
])?)?;
let unpivot_series = [ids_series, vec![variable_series, value_series]].concat();

Table::new(unpivot_schema, unpivot_series)
Table::new(unpivot_schema, unpivot_series, unpivoted_len)
}
}
4 changes: 3 additions & 1 deletion src/daft-table/src/python.rs
@@ -428,6 +428,8 @@ impl PyTable {
fields.push(Field::new(name.clone(), series.data_type().clone()));
columns.push(series.rename(name));
}

let num_rows = columns.first().map(|s| s.len()).unwrap_or(0);
if !columns.is_empty() {
let first = columns.first().unwrap();
for s in columns.iter().skip(1) {
@@ -443,7 +445,7 @@ impl PyTable {
}

Ok(PyTable {
table: Table::new(Schema::new(fields)?, columns)?,
table: Table::new(Schema::new(fields)?, columns, num_rows)?,
})
}

