diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index dc333e8be7..2d2e64b9df 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -58,6 +58,9 @@ use unhex::spark_unhex;
 mod hex;
 use hex::spark_hex;
 
+mod chr;
+use chr::spark_chr;
+
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
         let scalar_func = CometScalarFunction::new(
@@ -130,6 +133,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_xxhash64);
             make_comet_scalar_udf!("xxhash64", func, without data_type)
         }
+        "chr" => {
+            let func = Arc::new(spark_chr);
+            make_comet_scalar_udf!("chr", func, without data_type)
+        }
         sha if sha2_functions.contains(&sha) => {
             // Spark requires hex string as the result of sha2 functions, we have to wrap the
             // result of digest functions as hex string
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs b/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs
new file mode 100644
index 0000000000..3d62d324c7
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/chr.rs
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, StringArray},
+    datatypes::{
+        DataType,
+        DataType::{Int64, Utf8},
+    },
+};
+
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_common::{cast::as_int64_array, exec_err, DataFusionError, Result, ScalarValue};
+
+/// Returns the ASCII character having the binary equivalent to the input expression.
+/// E.g., chr(65) = 'A'.
+/// Compatible with Apache Spark's Chr function.
+pub fn spark_chr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let chr_func = ChrFunc::default();
+    chr_func.invoke(args)
+}
+
+pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let integer_array = as_int64_array(&args[0])?;
+
+    // first map is the iterator, second is for the `Option<_>`
+    let result = integer_array
+        .iter()
+        .map(|integer: Option<i64>| {
+            integer
+                .map(|integer| match core::char::from_u32(integer as u32) {
+                    Some(integer) => Ok(integer.to_string()),
+                    None => {
+                        exec_err!("requested character too large for encoding.")
+                    }
+                })
+                .transpose()
+        })
+        .collect::<Result<StringArray>>()?;
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+#[derive(Debug)]
+pub struct ChrFunc {
+    signature: Signature,
+}
+
+impl Default for ChrFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ChrFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for ChrFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "chr"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(Utf8)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        handle_chr_fn(args)
+    }
+}
+
+fn handle_chr_fn(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let array = args[0].clone();
+    match array {
+        ColumnarValue::Array(array) => {
+            let array = chr(&[array])?;
+            Ok(ColumnarValue::Array(array))
+        }
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => {
+            match core::char::from_u32(value as u32) {
+                Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                    ch.to_string(),
+                )))),
+                None => exec_err!("requested character too large for encoding."),
+            }
+        }
+        ColumnarValue::Scalar(ScalarValue::Int64(None)) => {
+            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
+        }
+        _ => exec_err!("The argument must be an Int64 array or scalar."),
+    }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e69054dbde..b3638b0573 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -980,6 +980,23 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("Chr with null character") {
+    // Test compatibility with Spark; Spark supports chr(0)
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf(
+        "parquet.enable.dictionary" -> dictionary.toString,
+        CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
+        val table = "test0"
+        withTable(table) {
+          sql(s"create table $table(c9 int, c4 int) using parquet")
+          sql(s"insert into $table values(0, 0), (66, null), (null, 70), (null, null)")
+          val query = s"SELECT chr(c9), chr(c4) FROM $table"
+          checkSparkAnswerAndOperator(query)
+        }
+      }
+    }
+  }
+
   test("InitCap") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {