diff --git a/src/daft-functions/src/minhash.rs b/src/daft-functions/src/minhash.rs
index 48d13e0a65..6c000c4a1a 100644
--- a/src/daft-functions/src/minhash.rs
+++ b/src/daft-functions/src/minhash.rs
@@ -7,10 +7,10 @@ use daft_dsl::{
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub(super) struct MinHashFunction {
-    num_hashes: usize,
-    ngram_size: usize,
-    seed: u32,
+pub struct MinHashFunction {
+    pub num_hashes: usize,
+    pub ngram_size: usize,
+    pub seed: u32,
 }
 
 #[typetag::serde]
diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs
index 6172d2d382..2475cd1628 100644
--- a/src/daft-sql/src/functions.rs
+++ b/src/daft-sql/src/functions.rs
@@ -1,6 +1,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use daft_dsl::ExprRef;
+use hashing::SQLModuleHashing;
 use once_cell::sync::Lazy;
 use sqlparser::ast::{
     Function, FunctionArg, FunctionArgExpr, FunctionArgOperator, FunctionArguments,
@@ -18,6 +19,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
     let mut functions = SQLFunctions::new();
     functions.register::<SQLModuleAggs>();
     functions.register::<SQLModuleFloat>();
+    functions.register::<SQLModuleHashing>();
     functions.register::<SQLModuleImage>();
     functions.register::<SQLModuleJson>();
     functions.register::<SQLModuleList>();
@@ -235,7 +237,10 @@ impl SQLPlanner {
         }
     }
 
-    fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult<ExprRef> {
+    pub(crate) fn try_unwrap_function_arg_expr(
+        &self,
+        expr: &FunctionArgExpr,
+    ) -> SQLPlannerResult<ExprRef> {
         match expr {
             FunctionArgExpr::Expr(expr) => self.plan_expr(expr),
             _ => unsupported_sql_err!("Wildcard function args not yet supported"),
diff --git a/src/daft-sql/src/modules/hashing.rs b/src/daft-sql/src/modules/hashing.rs
new file mode 100644
index 0000000000..4259ebd04a
--- /dev/null
+++ b/src/daft-sql/src/modules/hashing.rs
@@ -0,0 +1,129 @@
+use daft_dsl::ExprRef;
+use daft_functions::{
+    hash::hash,
+    minhash::{minhash, MinHashFunction},
+};
+use sqlparser::ast::FunctionArg;
+
+use super::SQLModule;
+use crate::{
+    error::{PlannerError, SQLPlannerResult},
+    functions::{SQLFunction, SQLFunctionArguments, SQLFunctions},
+    unsupported_sql_err,
+};
+
+/// SQL module that registers the hashing functions `hash` and `minhash`.
+pub struct SQLModuleHashing;
+
+impl SQLModule for SQLModuleHashing {
+    fn register(parent: &mut SQLFunctions) {
+        parent.add_fn("hash", SQLHash);
+        parent.add_fn("minhash", SQLMinhash);
+    }
+}
+
+/// SQL `hash(input)` / `hash(input, seed)` / `hash(input, seed:=...)`.
+pub struct SQLHash;
+
+impl SQLFunction for SQLHash {
+    fn to_expr(
+        &self,
+        inputs: &[FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input] => {
+                let input = planner.plan_function_arg(input)?;
+                Ok(hash(input, None))
+            }
+            [input, seed] => {
+                let input = planner.plan_function_arg(input)?;
+                match seed {
+                    FunctionArg::Named { name, arg, .. } if name.value == "seed" => {
+                        let seed = planner.try_unwrap_function_arg_expr(arg)?;
+                        Ok(hash(input, Some(seed)))
+                    }
+                    arg @ FunctionArg::Unnamed(_) => {
+                        let seed = planner.plan_function_arg(arg)?;
+                        Ok(hash(input, Some(seed)))
+                    }
+                    _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
+                }
+            }
+            _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
+        }
+    }
+}
+
+/// SQL `minhash(input, num_hashes:=..., ngram_size:=..., seed:=...)`.
+pub struct SQLMinhash;
+
+/// Builds a [`MinHashFunction`] from the named arguments of a SQL `minhash` call.
+///
+/// `num_hashes` and `ngram_size` are required; `seed` is optional and defaults to `1`.
+/// Every value must be an integer literal that fits the type of the field it fills.
+impl TryFrom<SQLFunctionArguments> for MinHashFunction {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        let num_hashes = args
+            .get_named("num_hashes")
+            .ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))?
+            .as_literal()
+            .and_then(|lit| lit.as_i64())
+            .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))?;
+        // Checked conversion: an `as` cast would silently wrap a negative value.
+        let num_hashes = usize::try_from(num_hashes).map_err(|_| {
+            PlannerError::invalid_operation("num_hashes must be a non-negative integer")
+        })?;
+
+        let ngram_size = args
+            .get_named("ngram_size")
+            .ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))?
+            .as_literal()
+            .and_then(|lit| lit.as_i64())
+            .ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))?;
+        let ngram_size = usize::try_from(ngram_size).map_err(|_| {
+            PlannerError::invalid_operation("ngram_size must be a non-negative integer")
+        })?;
+
+        // `seed` is optional; when it is absent the default of 1 is used.
+        let seed = args
+            .get_named("seed")
+            .map(|arg| {
+                arg.as_literal()
+                    .and_then(|lit| lit.as_i64())
+                    .ok_or_else(|| PlannerError::invalid_operation("seed must be an integer"))
+            })
+            .transpose()?
+            .unwrap_or(1);
+        let seed = u32::try_from(seed).map_err(|_| {
+            PlannerError::invalid_operation("seed must be an integer between 0 and 4294967295")
+        })?;
+
+        Ok(Self {
+            num_hashes,
+            ngram_size,
+            seed,
+        })
+    }
+}
+
+impl SQLFunction for SQLMinhash {
+    fn to_expr(
+        &self,
+        inputs: &[FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let args: MinHashFunction =
+                    planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?;
+
+                Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed))
+            }
+            _ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"),
+        }
+    }
+}
diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs
index 0f60ecbff9..989c401393 100644
--- a/src/daft-sql/src/modules/mod.rs
+++ b/src/daft-sql/src/modules/mod.rs
@@ -2,6 +2,7 @@ use crate::functions::SQLFunctions;
 
 pub mod aggs;
 pub mod float;
+pub mod hashing;
 pub mod image;
 pub mod json;
 pub mod list;
diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py
index 4debfc0885..e3ae320094 100644
--- a/tests/sql/test_exprs.py
+++ b/tests/sql/test_exprs.py
@@ -1,4 +1,7 @@
+import pytest
+
 import daft
+from daft import col
 
 
 def test_nested():
@@ -20,3 +23,47 @@ def test_nested():
     expected = df.with_column("try_this", df["A"] + 1).collect()
 
     assert actual.to_pydict() == expected.to_pydict()
+
+
+def test_hash_exprs():
+    df = daft.from_pydict(
+        {
+            "a": ["foo", "bar", "baz", "qux"],
+            "ints": [1, 2, 3, 4],
+            "floats": [1.5, 2.5, 3.5, 4.5],
+        }
+    )
+
+    actual = (
+        daft.sql("""
+    SELECT
+        hash(a) as hash_a,
+        hash(a, 0) as hash_a_0,
+        hash(a, seed:=0) as hash_a_seed_0,
+        minhash(a, num_hashes:=10, ngram_size:= 100, seed:=10) as minhash_a,
+        minhash(a, num_hashes:=10, ngram_size:= 100) as minhash_a_no_seed,
+    FROM df
+    """)
+        .collect()
+        .to_pydict()
+    )
+
+    expected = (
+        df.select(
+            col("a").hash().alias("hash_a"),
+            col("a").hash(0).alias("hash_a_0"),
+            col("a").hash(seed=0).alias("hash_a_seed_0"),
+            col("a").minhash(num_hashes=10, ngram_size=100, seed=10).alias("minhash_a"),
+            col("a").minhash(num_hashes=10, ngram_size=100).alias("minhash_a_no_seed"),
+        )
+        .collect()
+        .to_pydict()
+    )
+
+    assert actual == expected
+
+    with pytest.raises(Exception, match="Invalid arguments for minhash"):
+        daft.sql("SELECT minhash() as hash_a FROM df").collect()
+
+    with pytest.raises(Exception, match="num_hashes is required"):
+        daft.sql("SELECT minhash(a) as hash_a FROM df").collect()