Skip to content

Commit

Permalink
feat(ddl): support alter table cast tuple type column (#17310)
Browse files Browse the repository at this point in the history
* feat(ddl): support alter table cast tuple type column

* fix

* fix tests

* fix tests

* rename `array_tuple` as `arrays_zip`

* fix tests
  • Loading branch information
b41sh authored Jan 20, 2025
1 parent 3ea25c9 commit 7720353
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 18 deletions.
127 changes: 127 additions & 0 deletions src/query/functions/src/scalars/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use databend_common_expression::types::number::SimpleDomain;
use databend_common_expression::types::number::UInt64Type;
use databend_common_expression::types::AnyType;
use databend_common_expression::types::ArgType;
use databend_common_expression::types::ArrayColumn;
use databend_common_expression::types::ArrayType;
use databend_common_expression::types::BooleanType;
use databend_common_expression::types::DataType;
Expand Down Expand Up @@ -158,6 +159,132 @@ pub fn register(registry: &mut FunctionRegistry) {
}))
});

// Returns a merged array of tuples in which the nth tuple contains all nth values of input arrays.
registry.register_function_factory("arrays_zip", |_, args_type| {
if args_type.is_empty() {
return None;
}
let args_type = args_type.to_vec();

let inner_types: Vec<DataType> = args_type
.iter()
.map(|arg_type| {
let is_nullable = arg_type.is_nullable();
match arg_type.remove_nullable() {
DataType::Array(box inner_type) => {
if is_nullable {
inner_type.wrap_nullable()
} else {
inner_type.clone()
}
}
_ => arg_type.clone(),
}
})
.collect();
let return_type = DataType::Array(Box::new(DataType::Tuple(inner_types.clone())));
Some(Arc::new(Function {
signature: FunctionSignature {
name: "arrays_zip".to_string(),
args_type: args_type.clone(),
return_type,
},
eval: FunctionEval::Scalar {
calc_domain: Box::new(|_, args_domain| {
let inner_domains = args_domain
.iter()
.map(|arg_domain| match arg_domain {
Domain::Nullable(nullable_domain) => match &nullable_domain.value {
Some(box Domain::Array(Some(inner_domain))) => {
Domain::Nullable(NullableDomain {
has_null: nullable_domain.has_null,
value: Some(Box::new(*inner_domain.clone())),
})
}
_ => Domain::Nullable(nullable_domain.clone()),
},
Domain::Array(Some(box inner_domain)) => inner_domain.clone(),
_ => arg_domain.clone(),
})
.collect();
FunctionDomain::Domain(Domain::Array(Some(Box::new(Domain::Tuple(
inner_domains,
)))))
}),
eval: Box::new(move |args, ctx| {
let len = args.iter().find_map(|arg| match arg {
Value::Column(col) => Some(col.len()),
_ => None,
});

let mut offset = 0;
let mut offsets = Vec::new();
offsets.push(0);
let tuple_type = DataType::Tuple(inner_types.clone());
let mut builder = ColumnBuilder::with_capacity(&tuple_type, 0);
for i in 0..len.unwrap_or(1) {
let mut is_diff_len = false;
let mut array_len = None;
for arg in args {
let value = unsafe { arg.index_unchecked(i) };
if let ScalarRef::Array(col) = value {
if let Some(array_len) = array_len {
if array_len != col.len() {
is_diff_len = true;
let err = format!(
"array length must be equal, but got {} and {}",
array_len,
col.len()
);
ctx.set_error(builder.len(), err);
offsets.push(offset);
break;
}
} else {
array_len = Some(col.len());
}
}
}
if is_diff_len {
continue;
}
let array_len = array_len.unwrap_or(1);
for j in 0..array_len {
let mut tuple_values = Vec::with_capacity(args.len());
for arg in args {
let value = unsafe { arg.index_unchecked(i) };
match value {
ScalarRef::Array(col) => {
let tuple_value = unsafe { col.index_unchecked(j) };
tuple_values.push(tuple_value.to_owned());
}
_ => {
tuple_values.push(value.to_owned());
}
}
}
let tuple_value = Scalar::Tuple(tuple_values);
builder.push(tuple_value.as_ref());
}
offset += array_len as u64;
offsets.push(offset);
}

match len {
Some(_) => {
let array_column = ArrayColumn {
values: builder.build(),
offsets: offsets.into(),
};
Value::Column(Column::Array(Box::new(array_column)))
}
_ => Value::Scalar(Scalar::Array(builder.build())),
}
}),
},
}))
});

registry.register_1_arg::<EmptyArrayType, NumberType<u8>, _, _>(
"length",
|_, _| FunctionDomain::Domain(SimpleDomain { min: 0, max: 0 }),
Expand Down
8 changes: 8 additions & 0 deletions src/query/functions/tests/it/scalars/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ fn test_array() {
test_array_kurtosis(file);
test_array_skewness(file);
test_array_sort(file);
test_arrays_zip(file);
}

fn test_create(file: &mut impl Write) {
Expand Down Expand Up @@ -731,3 +732,10 @@ fn test_array_sort(file: &mut impl Write) {
&[],
);
}

fn test_arrays_zip(file: &mut impl Write) {
run_ast(file, "arrays_zip(NULL, NULL)", &[]);
run_ast(file, "arrays_zip(1, 2, 'a')", &[]);
run_ast(file, "arrays_zip([1,2,3], ['a','b','c'], 10)", &[]);
run_ast(file, "arrays_zip([1,2,3], ['a','b'], 10)", &[]);
}
35 changes: 35 additions & 0 deletions src/query/functions/tests/it/scalars/testdata/array.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2382,3 +2382,38 @@ output domain : [{0.0..=5.6} ∪ {NULL}]
output : [5.6, 3.4, 2.2, 1.2, NULL, NULL]


ast : arrays_zip(NULL, NULL)
raw expr : arrays_zip(NULL, NULL)
checked expr : arrays_zip<NULL, NULL>(NULL, NULL)
optimized expr : [(NULL, NULL)]
output type : Array(Tuple(NULL, NULL))
output domain : [({NULL}, {NULL})]
output : [(NULL, NULL)]


ast : arrays_zip(1, 2, 'a')
raw expr : arrays_zip(1, 2, 'a')
checked expr : arrays_zip<UInt8, UInt8, String>(1_u8, 2_u8, "a")
optimized expr : [(1, 2, 'a')]
output type : Array(Tuple(UInt8, UInt8, String))
output domain : [({1..=1}, {2..=2}, {"a"..="a"})]
output : [(1, 2, 'a')]


ast : arrays_zip([1,2,3], ['a','b','c'], 10)
raw expr : arrays_zip(array(1, 2, 3), array('a', 'b', 'c'), 10)
checked expr : arrays_zip<Array(UInt8), Array(String), UInt8>(array<T0=UInt8><T0, T0, T0>(1_u8, 2_u8, 3_u8), array<T0=String><T0, T0, T0>("a", "b", "c"), 10_u8)
optimized expr : [(1, 'a', 10), (2, 'b', 10), (3, 'c', 10)]
output type : Array(Tuple(UInt8, String, UInt8))
output domain : [({1..=3}, {"a"..="c"}, {10..=10})]
output : [(1, 'a', 10), (2, 'b', 10), (3, 'c', 10)]


error:
--> SQL:1:1
|
1 | arrays_zip([1,2,3], ['a','b'], 10)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ array length must be equal, but got 3 and 2 while evaluating function `arrays_zip([1, 2, 3], ['a', 'b'], 10)` in expr `arrays_zip(array(1, 2, 3), array('a', 'b'), 10)`



Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Functions overloads:
1 array_unique(Array(Nothing) NULL) :: UInt64 NULL
2 array_unique(Array(T0)) :: UInt64
3 array_unique(Array(T0) NULL) :: UInt64 NULL
0 arrays_zip FACTORY
0 as_array(Variant) :: Variant NULL
1 as_array(Variant NULL) :: Variant NULL
0 as_boolean(Variant) :: Boolean NULL
Expand Down
Loading

0 comments on commit 7720353

Please sign in to comment.