Skip to content

Commit 8de0101

Browse files
[FEAT] connect: add modulus operator and withColumns support (#3351)
- Add `%` operator and `sum` function to unresolved functions - Implement withColumns transformation - Add test coverage for group by with modulus operation
1 parent 86523a0 commit 8de0101

File tree

4 files changed

+83
-1
lines changed

4 files changed

+83
-1
lines changed

src/daft-connect/src/translation/expr/unresolved_function.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,27 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
3232
.wrap_err("Failed to handle <= function"),
3333
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
3434
.wrap_err("Failed to handle >= function"),
35+
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
36+
.wrap_err("Failed to handle % function"),
37+
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
3538
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
3639
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
3740
n => bail!("Unresolved function {n} not yet supported"),
3841
}
3942
}
4043

44+
/// Translate Spark's `sum` aggregate function into a Daft `sum` expression.
///
/// Expects exactly one argument; any other arity is an error. The incoming
/// `Vec` is converted into a one-element array so the single expression can
/// be moved out without cloning.
pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // `try_into` hands the Vec back on failure, letting us show the offending
    // arguments in the error message.
    let [arg]: [daft_dsl::ExprRef; 1] = arguments
        .try_into()
        .map_err(|arguments| eyre::eyre!("requires exactly one argument; got {arguments:?}"))?;

    Ok(arg.sum())
}
55+
4156
pub fn handle_binary_op(
4257
arguments: Vec<daft_dsl::ExprRef>,
4358
op: daft_dsl::Operator,

src/daft-connect/src/translation/logical_plan.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ use tracing::warn;
99

1010
use crate::translation::logical_plan::{
1111
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
12-
to_df::to_df,
12+
to_df::to_df, with_columns::with_columns,
1313
};
1414

1515
mod aggregate;
1616
mod local_relation;
1717
mod project;
1818
mod range;
1919
mod to_df;
20+
mod with_columns;
2021

2122
#[derive(Constructor)]
2223
pub struct Plan {
@@ -49,6 +50,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
4950
RelType::Aggregate(a) => {
5051
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
5152
}
53+
RelType::WithColumns(w) => {
54+
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
55+
}
5256
RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"),
5357
RelType::LocalRelation(l) => {
5458
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use eyre::bail;
2+
use spark_connect::{expression::ExprType, Expression};
3+
4+
use crate::translation::{to_daft_expr, to_logical_plan, Plan};
5+
6+
pub fn with_columns(with_columns: spark_connect::WithColumns) -> eyre::Result<Plan> {
7+
let spark_connect::WithColumns { input, aliases } = with_columns;
8+
9+
let Some(input) = input else {
10+
bail!("input is required");
11+
};
12+
13+
let mut plan = to_logical_plan(*input)?;
14+
15+
let daft_exprs: Vec<_> = aliases
16+
.into_iter()
17+
.map(|alias| {
18+
let expression = Expression {
19+
common: None,
20+
expr_type: Some(ExprType::Alias(Box::new(alias))),
21+
};
22+
23+
to_daft_expr(&expression)
24+
})
25+
.try_collect()?;
26+
27+
plan.builder = plan.builder.with_columns(daft_exprs)?;
28+
29+
Ok(plan)
30+
}

tests/connect/test_group_by.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
from pyspark.sql.functions import col
4+
5+
6+
def test_group_by(spark_session):
7+
# Create DataFrame from range(10)
8+
df = spark_session.range(10)
9+
10+
# Add a column that will have repeated values for grouping
11+
df = df.withColumn("group", col("id") % 3)
12+
13+
# Group by the new column and sum the ids in each group
14+
df_grouped = df.groupBy("group").sum("id")
15+
16+
# Convert to pandas to verify the sums
17+
df_grouped_pandas = df_grouped.toPandas()
18+
19+
# Sort by group to ensure consistent order for comparison
20+
df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True)
21+
22+
# Verify the expected sums for each group
23+
# group id
24+
# 0 2 15
25+
# 1 1 12
26+
# 2 0 18
27+
expected = {
28+
"group": [0, 1, 2],
29+
"id": [18, 12, 15], # todo(correctness): should this be "id" for value here?
30+
}
31+
32+
assert df_grouped_pandas["group"].tolist() == expected["group"]
33+
assert df_grouped_pandas["id"].tolist() == expected["id"]

0 commit comments

Comments
 (0)