Skip to content

Commit

Permalink
fix tests and hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Sep 26, 2024
1 parent 185f634 commit 4d5ae1e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,10 @@ impl SQLPlanner {
}
}

fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult<ExprRef> {
pub(crate) fn try_unwrap_function_arg_expr(
&self,
expr: &FunctionArgExpr,
) -> SQLPlannerResult<ExprRef> {
match expr {
FunctionArgExpr::Expr(expr) => self.plan_expr(expr),
_ => unsupported_sql_err!("Wildcard function args not yet supported"),
Expand Down
13 changes: 9 additions & 4 deletions src/daft-sql/src/modules/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl SQLFunction for SQLHash {
[input, seed] => {
let input = planner.plan_function_arg(input)?;
match seed {
arg @ FunctionArg::Named { name, .. } if name.value == "seed" => {
let seed = planner.plan_function_arg(arg)?;
FunctionArg::Named { name, arg, .. } if name.value == "seed" => {
let seed = planner.try_unwrap_function_arg_expr(arg)?;
Ok(hash(input, Some(seed)))
}
arg @ FunctionArg::Unnamed(_) => {
Expand All @@ -61,12 +61,17 @@ impl TryFrom<SQLFunctionArguments> for MinHashFunction {
fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
let num_hashes = args
.get_named("num_hashes")
.and_then(|arg| arg.as_literal().and_then(|lit| lit.as_i64()))
.ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))?
.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))?
as usize;

let ngram_size = args
.get_named("ngram_size")
.and_then(|arg| arg.as_literal().and_then(|lit| lit.as_i64()))
.ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))?
.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))?
as usize;
let seed = args
Expand Down
47 changes: 47 additions & 0 deletions tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

import daft
from daft import col


def test_nested():
Expand All @@ -20,3 +23,47 @@ def test_nested():
expected = df.with_column("try_this", df["A"] + 1).collect()

assert actual.to_pydict() == expected.to_pydict()


def test_hash_exprs():
df = daft.from_pydict(
{
"a": ["foo", "bar", "baz", "qux"],
"ints": [1, 2, 3, 4],
"floats": [1.5, 2.5, 3.5, 4.5],
}
)

actual = (
daft.sql("""
SELECT
hash(a) as hash_a,
hash(a, 0) as hash_a_0,
hash(a, seed:=0) as hash_a_seed_0,
minhash(a, num_hashes:=10, ngram_size:= 100, seed:=10) as minhash_a,
minhash(a, num_hashes:=10, ngram_size:= 100) as minhash_a_no_seed,
FROM df
""")
.collect()
.to_pydict()
)

expected = (
df.select(
col("a").hash().alias("hash_a"),
col("a").hash(0).alias("hash_a_0"),
col("a").hash(seed=0).alias("hash_a_seed_0"),
col("a").minhash(num_hashes=10, ngram_size=100, seed=10).alias("minhash_a"),
col("a").minhash(num_hashes=10, ngram_size=100).alias("minhash_a_no_seed"),
)
.collect()
.to_pydict()
)

assert actual == expected

with pytest.raises(Exception, match="Invalid arguments for minhash"):
daft.sql("SELECT minhash() as hash_a FROM df").collect()

with pytest.raises(Exception, match="num_hashes is required"):
daft.sql("SELECT minhash(a) as hash_a FROM df").collect()

0 comments on commit 4d5ae1e

Please sign in to comment.