Skip to content

Commit

Permalink
[BUG]: add count_matches and fix a bunch of str functions (#2946)
Browse files Browse the repository at this point in the history
a bunch of improvements to the str functions in SQL
  • Loading branch information
universalmind303 authored Sep 30, 2024
1 parent f10d4da commit b2dabf6
Show file tree
Hide file tree
Showing 10 changed files with 491 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ pub enum PadPlacement {
Right,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub struct Utf8NormalizeOptions {
pub remove_punct: bool,
pub lowercase: bool,
Expand Down
6 changes: 3 additions & 3 deletions src/daft-functions/src/count_matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use daft_dsl::{
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
struct CountMatchesFunction {
pub(super) whole_words: bool,
pub(super) case_sensitive: bool,
pub struct CountMatchesFunction {
pub whole_words: bool,
pub case_sensitive: bool,
}

#[typetag::serde]
Expand Down
10 changes: 5 additions & 5 deletions src/daft-functions/src/tokenize/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ fn tokenize_decode_series(
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct TokenizeDecodeFunction {
pub(super) tokens_path: String,
pub(super) io_config: Option<Arc<IOConfig>>,
pub(super) pattern: Option<String>,
pub(super) special_tokens: Option<String>,
pub struct TokenizeDecodeFunction {
pub tokens_path: String,
pub io_config: Option<Arc<IOConfig>>,
pub pattern: Option<String>,
pub special_tokens: Option<String>,
}

#[typetag::serde]
Expand Down
12 changes: 6 additions & 6 deletions src/daft-functions/src/tokenize/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ fn tokenize_encode_series(
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct TokenizeEncodeFunction {
pub(super) tokens_path: String,
pub(super) io_config: Option<Arc<IOConfig>>,
pub(super) pattern: Option<String>,
pub(super) special_tokens: Option<String>,
pub(super) use_special_tokens: bool,
pub struct TokenizeEncodeFunction {
pub tokens_path: String,
pub io_config: Option<Arc<IOConfig>>,
pub pattern: Option<String>,
pub special_tokens: Option<String>,
pub use_special_tokens: bool,
}

#[typetag::serde]
Expand Down
4 changes: 2 additions & 2 deletions src/daft-functions/src/tokenize/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use daft_dsl::{functions::ScalarFunction, ExprRef};
use daft_io::IOConfig;
use decode::TokenizeDecodeFunction;
use encode::TokenizeEncodeFunction;
pub use decode::TokenizeDecodeFunction;
pub use encode::TokenizeEncodeFunction;

mod bpe;
mod decode;
Expand Down
50 changes: 49 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,54 @@ impl SQLFunctionArguments {
/// Looks up a named SQL function argument.
///
/// Delegates to `self.named.get(name)` and hands back a borrowed expression
/// without cloning it. The result is `None` whenever that lookup yields `None`
/// (presumably when no argument called `name` was supplied — confirm against
/// the type of `named`).
pub fn get_named(&self, name: &str) -> Option<&ExprRef> {
self.named.get(name)
}

/// Fetches the named argument `name` and converts it into `T`.
///
/// # Returns
/// * `Ok(None)` when the lookup in `self.named` finds nothing.
/// * `Ok(Some(value))` when the argument exists and `T::from_expr` accepts it.
///
/// # Errors
/// Propagates the `PlannerError` produced by `T::from_expr` when the argument
/// exists but cannot be converted.
pub fn try_get_named<T: SQLLiteral>(&self, name: &str) -> Result<Option<T>, PlannerError> {
    match self.named.get(name) {
        // Argument supplied: convert it, lifting the success value into `Some`.
        Some(arg) => T::from_expr(arg).map(Some),
        // Argument not supplied: that is not an error, just an absent value.
        None => Ok(None),
    }
}
}

/// Conversion from a planner expression (`ExprRef`) into a plain Rust value.
///
/// Used as the bound on `SQLFunctionArguments::try_get_named` so that a named
/// function argument can be extracted directly as the implementing type.
pub trait SQLLiteral {
/// Attempts to build `Self` from `expr`.
///
/// # Errors
/// Returns a `PlannerError` when `expr` cannot be converted into `Self`.
fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
where
Self: Sized;
}

impl SQLLiteral for String {
    /// Extracts an owned `String` from `expr`.
    ///
    /// # Errors
    /// Returns `PlannerError::invalid_operation("Expected a string literal")`
    /// when `expr.as_literal()` or the subsequent `as_str()` yields `None`.
    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
    where
        Self: Sized,
    {
        match expr.as_literal().and_then(|lit| lit.as_str()) {
            // Copy the borrowed text so the result does not borrow from `expr`.
            Some(text) => Ok(text.to_string()),
            None => Err(PlannerError::invalid_operation(
                "Expected a string literal",
            )),
        }
    }
}

impl SQLLiteral for i64 {
    /// Extracts an `i64` from `expr`.
    ///
    /// # Errors
    /// Returns `PlannerError::invalid_operation("Expected an integer literal")`
    /// when `expr.as_literal()` or the subsequent `as_i64()` yields `None`.
    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
    where
        Self: Sized,
    {
        let maybe_int = expr.as_literal().and_then(|lit| lit.as_i64());
        if let Some(number) = maybe_int {
            Ok(number)
        } else {
            Err(PlannerError::invalid_operation(
                "Expected an integer literal",
            ))
        }
    }
}

impl SQLLiteral for bool {
    /// Extracts a `bool` from `expr`.
    ///
    /// # Errors
    /// Returns `PlannerError::invalid_operation("Expected a boolean literal")`
    /// when `expr.as_literal()` or the subsequent `as_bool()` yields `None`.
    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
    where
        Self: Sized,
    {
        let maybe_flag = expr.as_literal().and_then(|lit| lit.as_bool());
        match maybe_flag {
            Some(flag) => Ok(flag),
            None => Err(PlannerError::invalid_operation(
                "Expected a boolean literal",
            )),
        }
    }
}

impl SQLFunctions {
Expand Down Expand Up @@ -216,7 +264,7 @@ impl SQLPlanner {
}
positional_args.insert(idx, self.try_unwrap_function_arg_expr(arg)?);
}
_ => unsupported_sql_err!("unsupported function argument type"),
other => unsupported_sql_err!("unsupported function argument type: {other}, valid function arguments for this function are: {expected_named:?}."),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ mod tests {
#[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")]
#[case::contains("select contains(utf8, 'a') as contains from tbl1")]
#[case::split("select split(utf8, '.') as split from tbl1")]
#[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")]
#[case::replace("select regexp_replace(utf8, 'a', 'b') as replace from tbl1")]
#[case::length("select length(utf8) as length from tbl1")]
#[case::lower("select lower(utf8) as lower from tbl1")]
#[case::upper("select upper(utf8) as upper from tbl1")]
Expand Down
Loading

0 comments on commit b2dabf6

Please sign in to comment.