Skip to content

Commit

Permalink
[CHORE] Return &str for expression name (#2224)
Browse files Browse the repository at this point in the history
We return a `DaftResult<&str>` for expression names, but currently
there's no case that would raise an error. Having the return type as
`&str` should be sufficient.
  • Loading branch information
colin-ho authored May 6, 2024
1 parent 38ab44a commit 29d310b
Show file tree
Hide file tree
Showing 15 changed files with 30 additions and 40 deletions.
22 changes: 11 additions & 11 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub fn binary_op(op: Operator, left: ExprRef, right: ExprRef) -> ExprRef {
}

impl AggExpr {
pub fn name(&self) -> DaftResult<&str> {
pub fn name(&self) -> &str {
use AggExpr::*;
match self {
Count(expr, ..)
Expand Down Expand Up @@ -641,19 +641,19 @@ impl Expr {
match self {
Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)),
Agg(agg_expr) => agg_expr.to_field(schema),
Cast(expr, dtype) => Ok(Field::new(expr.name()?, dtype.clone())),
Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())),
Column(name) => Ok(schema.get_field(name).cloned()?),
Not(expr) => {
let child_field = expr.to_field(schema)?;
match child_field.dtype {
DataType::Boolean => Ok(Field::new(expr.name()?, DataType::Boolean)),
DataType::Boolean => Ok(Field::new(expr.name(), DataType::Boolean)),
_ => Err(DaftError::TypeError(format!(
"Expected argument to be a Boolean expression, but received {child_field}",
))),
}
}
IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
FillNull(expr, fill_value) => {
let expr_field = expr.to_field(schema)?;
let fill_value_field = fill_value.to_field(schema)?;
Expand Down Expand Up @@ -736,7 +736,7 @@ impl Expr {
match predicate.as_ref() {
Expr::Literal(lit::LiteralValue::Boolean(true)) => if_true.to_field(schema),
Expr::Literal(lit::LiteralValue::Boolean(false)) => {
Ok(if_false.to_field(schema)?.rename(if_true.name()?))
Ok(if_false.to_field(schema)?.rename(if_true.name()))
}
_ => {
let if_true_field = if_true.to_field(schema)?;
Expand All @@ -751,21 +751,21 @@ impl Expr {
}
}

pub fn name(&self) -> DaftResult<&str> {
pub fn name(&self) -> &str {
use Expr::*;
match self {
Alias(.., name) => Ok(name.as_ref()),
Alias(.., name) => name.as_ref(),
Agg(agg_expr) => agg_expr.name(),
Cast(expr, ..) => expr.name(),
Column(name) => Ok(name.as_ref()),
Column(name) => name.as_ref(),
Not(expr) => expr.name(),
IsNull(expr) => expr.name(),
NotNull(expr) => expr.name(),
FillNull(expr, ..) => expr.name(),
IsIn(expr, ..) => expr.name(),
Literal(..) => Ok("literal"),
Literal(..) => "literal",
Function { func, inputs } => match func {
FunctionExpr::Struct(StructExpr::Get(name)) => Ok(name),
FunctionExpr::Struct(StructExpr::Get(name)) => name,
_ => inputs.first().unwrap().name(),
},
BinaryOp {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/functions/list/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl FunctionEvaluator for CountEvaluator {
match input_field.dtype {
DataType::List(_) | DataType::FixedSizeList(_, _) => match expr {
FunctionExpr::List(ListExpr::Count(_)) => {
Ok(Field::new(input.name()?, DataType::UInt64))
Ok(Field::new(input.name(), DataType::UInt64))
}
_ => panic!("Expected List Count Expr, got {expr}"),
},
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/functions/python/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl FunctionEvaluator for PythonUDF {
[] => Err(DaftError::ValueError(
"Cannot run UDF with 0 expression arguments".into(),
)),
[first, ..] => Ok(Field::new(first.name()?, self.return_dtype.clone())),
[first, ..] => Ok(Field::new(first.name(), self.return_dtype.clone())),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl PyExpr {
}

pub fn name(&self) -> PyResult<&str> {
Ok(self.expr.name()?)
Ok(self.expr.name())
}

pub fn to_sql(&self) -> PyResult<Option<String>> {
Expand Down
5 changes: 1 addition & 4 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,7 @@ impl LogicalPlanBuilder {
) -> DaftResult<Self> {
err_if_agg("with_columns", &columns)?;

let new_col_names = columns
.iter()
.map(|e| e.name())
.collect::<DaftResult<HashSet<&str>>>()?;
let new_col_names = columns.iter().map(|e| e.name()).collect::<HashSet<&str>>();

let mut exprs = self
.schema()
Expand Down
6 changes: 1 addition & 5 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ impl Join {
// Schema inference ported from existing behaviour for parity,
// but contains bug https://github.com/Eventual-Inc/Daft/issues/1294
let output_schema = {
let left_join_keys = left_on
.iter()
.map(|e| e.name())
.collect::<common_error::DaftResult<HashSet<_>>>()
.context(CreationSnafu)?;
let left_join_keys = left_on.iter().map(|e| e.name()).collect::<HashSet<_>>();
let left_schema = &left.schema().fields;
let fields = left_schema
.iter()
Expand Down
4 changes: 2 additions & 2 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ impl Project {
// The substitution can unintentionally change the expression's name
// (since the name depends on the first column referenced, which can be substituted away)
// so re-alias the original name here if it has changed.
let old_name = e.name().unwrap();
if new_expr.name().unwrap() != old_name {
let old_name = e.name();
if new_expr.name() != old_name {
new_expr.alias(old_name)
} else {
new_expr.clone()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,7 @@ impl OptimizerRule for PushDownFilter {
let projection_input_mapping = child_project
.projection
.iter()
.filter_map(|e| {
e.input_mapping()
.map(|s| (e.name().unwrap().to_string(), col(s)))
})
.filter_map(|e| e.input_mapping().map(|s| (e.name().to_string(), col(s))))
.collect::<HashMap<String, ExprRef>>();
// Split predicate expressions into those that don't depend on projection compute (can_push) and those
// that do (can_not_push).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl PushDownProjection {
.flat_map(|e| {
e.input_mapping().map_or_else(
// None means computation required -> Some(colname)
|| Some(e.name().unwrap().to_string()),
|| Some(e.name().to_string()),
// Some(computation not required) -> None
|_| None,
)
Expand Down Expand Up @@ -107,7 +107,7 @@ impl PushDownProjection {
let upstream_names_to_exprs = upstream_projection
.projection
.iter()
.map(|e| (e.name().unwrap().to_string(), e.clone()))
.map(|e| (e.name().to_string(), e.clone()))
.collect::<HashMap<_, _>>();

// Merge the projections by applying the upstream expression substitutions
Expand Down Expand Up @@ -184,7 +184,7 @@ impl PushDownProjection {
let pruned_upstream_projections = upstream_projection
.projection
.iter()
.filter(|&e| required_columns.contains(e.name().unwrap()))
.filter(|&e| required_columns.contains(e.name()))
.cloned()
.collect::<Vec<_>>();

Expand All @@ -211,7 +211,7 @@ impl PushDownProjection {
let pruned_aggregate_exprs = aggregate
.aggregations
.iter()
.filter(|&e| required_columns.contains(e.name().unwrap()))
.filter(|&e| required_columns.contains(e.name()))
.cloned()
.collect::<Vec<_>>();

Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/physical_ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl Explode {
.flat_map(get_required_columns)
.collect::<HashSet<String>>();
for expr in to_explode {
let newname = expr.name().unwrap().to_string();
let newname = expr.name().to_string();
// if we clobber one of the required columns for the clustering_spec, invalidate it.
if required_cols_for_clustering_spec.contains(&newname) {
return ClusteringSpec::Unknown(UnknownClusteringConfig::new(
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/physical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Project {
let mut old_colname_to_new_colname = IndexMap::new();
for expr in projection {
if let Some(oldname) = expr.input_mapping() {
let newname = expr.name().unwrap().to_string();
let newname = expr.name().to_string();
// Add the oldname -> newname mapping,
// but don't overwrite any existing identity mappings (e.g. "a" -> "a").
if old_colname_to_new_colname.get(&oldname) != Some(&oldname) {
Expand Down
4 changes: 2 additions & 2 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ pub(super) fn translate_single_logical_node(
};
let join_strategy = join_strategy.unwrap_or_else(|| {
let is_primitive = |exprs: &Vec<ExprRef>| {
exprs.iter().map(|e| e.name().unwrap()).all(|col| {
exprs.iter().map(|e| e.name()).all(|col| {
let dtype = &output_schema.get_field(col).unwrap().dtype;
dtype.is_integer()
|| dtype.is_floating()
Expand Down Expand Up @@ -700,7 +700,7 @@ fn populate_aggregation_stages(
let mut final_exprs: Vec<ExprRef> = group_by.to_vec();

for agg_expr in aggregations {
let output_name = agg_expr.name().unwrap();
let output_name = agg_expr.name();
match agg_expr {
Count(e, mode) => {
let count_id = agg_expr.semantic_id(schema).id;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ impl Table {
} => match predicate.as_ref() {
Expr::Literal(LiteralValue::Boolean(true)) => self.eval_expression(if_true),
Expr::Literal(LiteralValue::Boolean(false)) => {
Ok(self.eval_expression(if_false)?.rename(if_true.name()?))
Ok(self.eval_expression(if_false)?.rename(if_true.name()))
}
_ => {
let if_true_series = self.eval_expression(if_true)?;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Table {
return Err(DaftError::ValueError(format!("ListExpr::Explode function expression must have one input only, received: {}", inputs.len())));
}
let expr = inputs.first().unwrap();
let exploded_name = expr.name()?;
let exploded_name = expr.name();
let evaluated = self.eval_expression(expr)?;
if !matches!(
evaluated.data_type(),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/ops/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl Table {
let right_names = right_on.iter().map(|e| e.name());
let zipped_names: DaftResult<_> = left_names
.zip(right_names)
.map(|(l, r)| Ok((l?, r?)))
.map(|(l, r)| Ok((l, r)))
.collect();
let zipped_names: Vec<(&str, &str)> = zipped_names?;
let right_to_left_keys: HashMap<&str, &str> =
Expand Down

0 comments on commit 29d310b

Please sign in to comment.