From 1f924f57316f8762a9fba1728e5d62d89bae84ec Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Fri, 13 Jan 2023 09:04:20 -0800 Subject: [PATCH 1/9] Enable `concat()` string function to support multiple string arguments (#200) * Refactor concat() to support multiple string arguments Signed-off-by: Margarit Hakobyan --- .../function/BuiltinFunctionRepository.java | 7 ++ .../sql/expression/function/FunctionDSL.java | 103 +++++++++++++++++- .../function/SerializableVarargsFunction.java | 22 ++++ .../function/VarargsFunctionResolver.java | 69 ++++++++++++ .../sql/expression/text/TextFunction.java | 19 ++-- .../BuiltinFunctionRepositoryTest.java | 31 ++++++ .../function/FunctionDSLTestBase.java | 2 + .../function/FunctionDSLimplVarargsTest.java | 32 ++++++ .../function/VarargsFunctionResolverTest.java | 81 ++++++++++++++ .../sql/expression/text/TextFunctionTest.java | 20 ++++ docs/user/dql/functions.rst | 16 +-- docs/user/ppl/functions/string.rst | 16 +-- .../opensearch/sql/ppl/TextFunctionIT.java | 4 +- .../opensearch/sql/sql/TextFunctionIT.java | 1 + .../expressions/text_functions.txt | 4 +- 15 files changed, 399 insertions(+), 28 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java create mode 100644 core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 71fd19991e4..f3a5590b98e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -182,6 +182,13 @@ private FunctionBuilder getFunctionBuilder( if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { return funcBuilder; } + // For functions with variable number of args (ex: concat()) + // targetTypes will always be empty (as the function signature is not fixed), + // and failure will occur. + // So, in this case sourceTypes are passed instead of targetTypes to address that. + if (functionResolverMap.get(functionName) instanceof VarargsFunctionResolver) { + return castArguments(sourceTypes, sourceTypes, funcBuilder); + } return castArguments(sourceTypes, targetTypes, funcBuilder); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index d94d7cdf601..c2814e9f4e2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -9,12 +9,14 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; @@ -58,6 +60,39 @@ public static DefaultFunctionResolver define(FunctionName functionName, List< return builder.build(); } + /** + * Define varargs function with implementation. + * + * @param functionName function name. + * @param functions a list of function implementation. + * @return VarargsFunctionResolver. + */ + public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, + SerializableFunction>... functions) { + return defineVarargsFunction(functionName, List.of(functions)); + } + + /** + * Define varargs function with implementation. + * + * @param functionName function name. + * @param functions a list of function implementation. + * @return VarargsFunctionResolver. + */ + public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, List< + SerializableFunction>> functions) { + + VarargsFunctionResolver.VarargsFunctionResolverBuilder builder = + VarargsFunctionResolver.builder(); + builder.functionName(functionName); + for (SerializableFunction> func + : functions) { + Pair functionBuilder = func.apply(functionName); + builder.functionBundle(functionBuilder.getKey(), functionBuilder.getValue()); + } + return builder.build(); + } /** * Implementation of no args function that uses FunctionProperties. @@ -212,6 +247,56 @@ public static SerializableFunction function.apply(arg), returnType, argsType); } + /** + * Varargs Function Implementation. + * This implementation considers 1...n args of the same type. + * + * @param function {@link ExprValue} based varargs function. + * @param returnType return type. + * @param argsType argument type. + * @return Varargs Function Implementation. + */ + public static SerializableFunction> impl( + SerializableVarargsFunction function, + ExprType returnType, + ExprType argsType, + boolean withVarargs) { + + return functionName -> { + AtomicInteger argsCount = new AtomicInteger(0); + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + argsCount.set(arguments.size()); + ExprValue[] args = arguments.stream() + .map(arg -> arg.valueOf(valueEnv)) + .collect(Collectors.toList()) + .toArray(new ExprValue[arguments.size()]); + + return function.apply(args); + } + + @Override + public ExprType type() { + return returnType; + } + + @Override + public String toString() { + return String.format("%s(%s)", functionName, arguments.stream() + .map(Object::toString) + .collect(Collectors.joining(", "))); + } + }; + ExprCoreType[] argsTypes = new ExprCoreType[argsCount.get()]; + Arrays.fill(argsTypes, argsType); + FunctionSignature functionSignature = + new FunctionSignature(functionName, List.of(argsTypes)); + return Pair.of(functionSignature, functionBuilder); + }; + } + /** * Binary Function Implementation. * @@ -323,13 +408,29 @@ public SerializableTriFunction nullM }; } + /** + * Wrapper the varargs ExprValue function with default NULL and MISSING handling. + */ + public SerializableVarargsFunction nullMissingHandling( + SerializableVarargsFunction function, boolean withVarargs) { + return (args) -> { + if (Arrays.stream(args).anyMatch(ExprValue::isMissing)) { + return ExprValueUtils.missingValue(); + } else if (Arrays.stream(args).anyMatch(ExprValue::isNull)) { + return ExprValueUtils.nullValue(); + } else { + return function.apply(args); + } + }; + } + /** * Wrapper the unary ExprValue function that is aware of FunctionProperties, * with default NULL and MISSING handling. */ public static SerializableBiFunction nullMissingHandlingWithProperties( - SerializableBiFunction implementation) { + SerializableBiFunction implementation) { return (functionProperties, v1) -> { if (v1.isMissing()) { return ExprValueUtils.missingValue(); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java new file mode 100644 index 00000000000..3c0c07fa79f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.expression.function; + +import java.io.Serializable; + +/** + * Serializable Varargs Function. + */ +public interface SerializableVarargsFunction extends Serializable { + /** + * Applies this function to the given arguments. + * + * @param t the function argument + * @return the function result + */ + R apply(T... t); +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java new file mode 100644 index 00000000000..5af054fb77e --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.AbstractMap; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Singular; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +/** + * The Function Resolver hold the overload {@link FunctionBuilder} implementation. + * is composed by {@link FunctionName} which identified the function name + * and a map of {@link FunctionSignature} and {@link FunctionBuilder} + * to represent the overloaded implementation + */ +@Builder +@RequiredArgsConstructor +public class VarargsFunctionResolver implements FunctionResolver { + @Getter + private final FunctionName functionName; + @Singular("functionBundle") + private final Map functionBundle; + + /** + * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. + * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. + * If applying the widening rule, found the most match one, return it. + * If nothing found, throw {@link ExpressionEvaluationException} + * + * @return function signature and its builder + */ + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + PriorityQueue> functionMatchQueue = new PriorityQueue<>( + Map.Entry.comparingByKey()); + + for (FunctionSignature functionSignature : functionBundle.keySet()) { + functionMatchQueue.add( + new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), + functionSignature)); + } + Map.Entry bestMatchEntry = functionMatchQueue.peek(); + if (unresolvedSignature.getParamTypeList().isEmpty()) { + throw new ExpressionEvaluationException( + String.format("%s function expected %s, but get %s", functionName, + formatFunctions(functionBundle.keySet()), + unresolvedSignature.formatTypes() + )); + } else { + FunctionSignature resolvedSignature = bestMatchEntry.getValue(); + return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); + } + } + + private String formatFunctions(Set functionSignatures) { + return functionSignatures.stream().map(FunctionSignature::formatTypes) + .collect(Collectors.joining(",", "{", "}")); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 25eb25489ca..e57c696785c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -9,9 +9,12 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.function.FunctionDSL.define; +import static org.opensearch.sql.expression.function.FunctionDSL.defineVarargsFunction; import static org.opensearch.sql.expression.function.FunctionDSL.impl; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; +import java.util.Arrays; +import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprStringValue; @@ -22,7 +25,7 @@ import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; - +import org.opensearch.sql.expression.function.VarargsFunctionResolver; /** * The definition of text functions. @@ -141,16 +144,16 @@ private DefaultFunctionResolver upper() { } /** - * TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710 - * Extend to accept variable argument amounts. * Concatenates a list of Strings. * Supports following signatures: - * (STRING, STRING) -> STRING + * (STRING, STRING, ...., STRING) -> STRING */ - private DefaultFunctionResolver concat() { - return define(BuiltinFunctionName.CONCAT.getName(), - impl(nullMissingHandling((str1, str2) -> - new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING)); + private VarargsFunctionResolver concat() { + return defineVarargsFunction(BuiltinFunctionName.CONCAT.getName(), + impl(nullMissingHandling(strings -> + new ExprStringValue(Arrays.stream(strings) + .map(ExprValue::stringValue) + .collect(Collectors.joining())), true), STRING, STRING, true)); } /** diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index 8bba3bd9b95..28fbb7bd974 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -238,6 +238,37 @@ void resolve_unregistered() { assertEquals("unsupported function name: unknown", exception.getMessage()); } + @Test + void resolve_should_cast_arguments_for_varargs_function() { + FunctionSignature unresolvedSignature = new FunctionSignature( + mockFunctionName, ImmutableList.of(STRING, STRING, STRING)); + FunctionSignature resolvedSignature = new FunctionSignature( + mockFunctionName, Collections.emptyList()); + + VarargsFunctionResolver varargsFunctionResolver = mock(VarargsFunctionResolver.class); + FunctionBuilder funcBuilder = mock(FunctionBuilder.class); + + when(mockFunctionName.getFunctionName()).thenReturn("mockFunction"); + when(mockExpression.toString()).thenReturn("string"); + when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap); + when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); + when(mockMap.containsKey(eq(mockFunctionName))).thenReturn(true); + when(mockMap.get(eq(mockFunctionName))).thenReturn(varargsFunctionResolver); + when(varargsFunctionResolver.resolve(eq(unresolvedSignature))).thenReturn( + Pair.of(resolvedSignature, funcBuilder)); + repo.register(varargsFunctionResolver); + // Relax unnecessary stubbing check because error case test doesn't call this + lenient().doAnswer(invocation -> + new FakeFunctionExpression(mockFunctionName, invocation.getArgument(1)) + ).when(funcBuilder).apply(eq(functionProperties), any()); + + FunctionImplementation function = + repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), unresolvedSignature) + .apply(functionProperties, + ImmutableList.of(mockExpression, mockExpression, mockExpression)); + assertEquals("mockFunction(string, string, string)", function.toString()); + } + private FunctionSignature registerFunctionResolver(FunctionName funcName, ExprType sourceType, ExprType targetType) { diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java index 63c6ea33296..5bcc9f9e89f 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java @@ -59,6 +59,8 @@ public int compareTo(ExprValue o) { twoArgs = (v1, v2) -> ANY; static final SerializableTriFunction threeArgs = (v1, v2, v3) -> ANY; + static final SerializableVarargsFunction + varrgs = (v1) -> ANY; @Mock FunctionProperties mockProperties; } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java new file mode 100644 index 00000000000..cee8889359b --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplVarargsTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(varrgs, ANY_TYPE, ANY_TYPE, true); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java new file mode 100644 index 00000000000..dafec21f402 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.type.WideningTypeRule; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +class VarargsFunctionResolverTest { + @Mock + private FunctionSignature exactlyMatchFS; + @Mock + private FunctionSignature bestMatchFS; + @Mock + private FunctionSignature leastMatchFS; + @Mock + private FunctionSignature notMatchFS; + @Mock + private FunctionSignature functionSignature; + @Mock + private FunctionBuilder exactlyMatchBuilder; + @Mock + private FunctionBuilder bestMatchBuilder; + @Mock + private FunctionBuilder leastMatchBuilder; + @Mock + private FunctionBuilder notMatchBuilder; + + private FunctionName functionName = FunctionName.of("test_function"); + + @Test + void resolve_function_signature_exactly_match() { + when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING, STRING, STRING)); + VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, + ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); + + assertEquals(exactlyMatchBuilder, resolver.resolve(functionSignature).getValue()); + } + + @Test + void resolve_function_signature_best_match() { + when(functionSignature.match(bestMatchFS)).thenReturn(1); + when(functionSignature.match(leastMatchFS)).thenReturn(2); + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING, STRING, STRING)); + VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, + ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); + + assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); + } + + @Test + void resolve_function_not_match() { + when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); + // accepts 1 or more arguments + when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList()); + VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, + ImmutableMap.of(notMatchFS, notMatchBuilder)); + + assertThrows(ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index 515b436c826..a0fd3b8c3cd 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase { private static List> CONCAT_STRING_LISTS = ImmutableList.of( ImmutableList.of("hello", "world"), ImmutableList.of("123", "5325")); + private static List> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of( + ImmutableList.of("he", "llo", "wo", "rld", "!"), + ImmutableList.of("0", "123", "53", "25", "7")); interface SubstrSubstring { FunctionExpression getFunction(SubstringInfo strInfo); @@ -228,6 +231,7 @@ public void upper() { @Test void concat() { CONCAT_STRING_LISTS.forEach(this::testConcatString); + CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString); when(nullRef.type()).thenReturn(STRING); when(missingRef.type()).thenReturn(STRING); @@ -446,6 +450,22 @@ void testConcatString(List strings, String delim) { assertEquals(expected, eval(expression).stringValue()); } + void testConcatMultipleString(List strings) { + String expected = null; + if (strings.stream().noneMatch(Objects::isNull)) { + expected = String.join("", strings); + } + + FunctionExpression expression = DSL.concat( + DSL.literal(strings.get(0)), + DSL.literal(strings.get(1)), + DSL.literal(strings.get(2)), + DSL.literal(strings.get(3)), + DSL.literal(strings.get(4))); + assertEquals(STRING, expression.type()); + assertEquals(expected, eval(expression).stringValue()); + } + void testLengthString(String str) { FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str))); assertEquals(INTEGER, expression.type()); diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index f433845bb3c..2ce44d91e53 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2614,21 +2614,21 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together. +Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. -Argument type: STRING, STRING +Argument type: STRING, STRING, ...., STRING Return type: STRING Example:: - os> SELECT CONCAT('hello', 'world') + os> SELECT CONCAT('hello', 'world'), CONCAT('hello ', 'whole ', 'world', '!'); fetched rows / total rows = 1/1 - +----------------------------+ - | CONCAT('hello', 'world') | - |----------------------------| - | helloworld | - +----------------------------+ + +----------------------------+--------------------------------------------+ + | CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') | + |----------------------------+--------------------------------------------| + | helloworld | hello whole world! | + +----------------------------+--------------------------------------------+ CONCAT_WS diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index 0503759cbdf..315a46616ca 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -14,21 +14,21 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together. +Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. -Argument type: STRING, STRING +Argument type: STRING, STRING, ...., STRING Return type: STRING Example:: - os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')` + os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')` fetched rows / total rows = 1/1 - +----------------------------+ - | CONCAT('hello', 'world') | - |----------------------------| - | helloworld | - +----------------------------+ + +----------------------------+--------------------------------------------+ + | CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') | + |----------------------------+--------------------------------------------| + | helloworld | hello whole world! | + +----------------------------+--------------------------------------------+ CONCAT_WS diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java index 7c48bceab0e..024f190bee3 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java @@ -99,8 +99,8 @@ public void testLtrim() throws IOException { @Test public void testConcat() throws IOException { - verifyQuery("concat", "", ", 'there'", - "hellothere", "worldthere", "helloworldthere"); + verifyQuery("concat", "", ", 'there', 'all', '!'", + "hellothereall!", "worldthereall!", "helloworldthereall!"); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java index 175cafd31e3..94677354e4f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java @@ -108,6 +108,7 @@ public void testLtrim() throws IOException { @Test public void testConcat() throws IOException { + verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!"); verifyQuery("concat('hello', 'world')", "keyword", "helloworld"); verifyQuery("concat('', 'hello')", "keyword", "hello"); } diff --git a/integ-test/src/test/resources/correctness/expressions/text_functions.txt b/integ-test/src/test/resources/correctness/expressions/text_functions.txt index c2fd57c330b..077cc82084b 100644 --- a/integ-test/src/test/resources/correctness/expressions/text_functions.txt +++ b/integ-test/src/test/resources/correctness/expressions/text_functions.txt @@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column LOCATE('world', 'hello') as column LOCATE('world', 'helloworld', 7) as column REPLACE('helloworld', 'world', 'opensearch') as column -REPLACE('hello', 'world', 'opensearch') as column \ No newline at end of file +REPLACE('hello', 'world', 'opensearch') as column +CONCAT('hello', 'world') as column +CONCAT('hello ', 'whole ', 'world', '!') as column \ No newline at end of file From 4ac26a27979bb0fda3e91669819747236c059536 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Fri, 13 Jan 2023 12:13:58 -0800 Subject: [PATCH 2/9] Add test case for null arg Signed-off-by: Margarit Hakobyan --- .../sql/expression/text/TextFunctionTest.java | 1 + docs/user/dql/functions.rst | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index a0fd3b8c3cd..54d2e5c400b 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -237,6 +237,7 @@ void concat() { when(missingRef.type()).thenReturn(STRING); assertEquals(missingValue(), eval( DSL.concat(missingRef, DSL.literal("1")))); + // If any of the expressions is a NULL value, it returns NULL. assertEquals(nullValue(), eval( DSL.concat(nullRef, DSL.literal("1")))); assertEquals(missingValue(), eval( diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 2ce44d91e53..69e20b07259 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2614,7 +2614,7 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. +Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. If any of the expressions is a NULL value, it returns NULL. Argument type: STRING, STRING, ...., STRING @@ -2622,13 +2622,13 @@ Return type: STRING Example:: - os> SELECT CONCAT('hello', 'world'), CONCAT('hello ', 'whole ', 'world', '!'); + os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null) fetched rows / total rows = 1/1 - +----------------------------+--------------------------------------------+ - | CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') | - |----------------------------+--------------------------------------------| - | helloworld | hello whole world! | - +----------------------------+--------------------------------------------+ + +--------------------------------------------+----------------------------+-------------------------+ + | CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) | + |--------------------------------------------+----------------------------+-------------------------| + | hello whole world! | helloworld | null | + +--------------------------------------------+----------------------------+-------------------------+ CONCAT_WS From bed81bfefc360c00ed64e9609ceb3f02f2513fc9 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Fri, 13 Jan 2023 14:07:31 -0800 Subject: [PATCH 3/9] Address PR review feedback Signed-off-by: Margarit Hakobyan --- .../sql/expression/function/BuiltinFunctionRepository.java | 6 +----- .../org/opensearch/sql/expression/function/FunctionDSL.java | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index f3a5590b98e..aaf587af82a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -182,12 +182,8 @@ private FunctionBuilder getFunctionBuilder( if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { return funcBuilder; } - // For functions with variable number of args (ex: concat()) - // targetTypes will always be empty (as the function signature is not fixed), - // and failure will occur. - // So, in this case sourceTypes are passed instead of targetTypes to address that. if (functionResolverMap.get(functionName) instanceof VarargsFunctionResolver) { - return castArguments(sourceTypes, sourceTypes, funcBuilder); + return funcBuilder; } return castArguments(sourceTypes, targetTypes, funcBuilder); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index c2814e9f4e2..5301eb2dd5c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -416,11 +416,11 @@ public SerializableVarargsFunction nullMissingHandling( return (args) -> { if (Arrays.stream(args).anyMatch(ExprValue::isMissing)) { return ExprValueUtils.missingValue(); - } else if (Arrays.stream(args).anyMatch(ExprValue::isNull)) { + } + if (Arrays.stream(args).anyMatch(ExprValue::isNull)) { return ExprValueUtils.nullValue(); - } else { - return function.apply(args); } + return function.apply(args); }; } From 30716a395f8c020cba34117dd6f0c19d495aee5e Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 17 Jan 2023 11:44:02 -0800 Subject: [PATCH 4/9] Removed VarargsFunctionResolver Signed-off-by: Margarit Hakobyan --- .../function/BuiltinFunctionRepository.java | 2 +- .../function/DefaultFunctionResolver.java | 8 +- .../sql/expression/function/FunctionDSL.java | 33 -------- .../function/VarargsFunctionResolver.java | 69 ---------------- .../sql/expression/text/TextFunction.java | 6 +- .../BuiltinFunctionRepositoryTest.java | 31 ------- .../function/DefaultFunctionResolverTest.java | 32 ++++++++ .../function/VarargsFunctionResolverTest.java | 81 ------------------- 8 files changed, 42 insertions(+), 220 deletions(-) delete mode 100644 core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java delete mode 100644 core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index aaf587af82a..8ddcfea5580 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -182,7 +182,7 @@ private FunctionBuilder getFunctionBuilder( if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { return funcBuilder; } - if (functionResolverMap.get(functionName) instanceof VarargsFunctionResolver) { + if (functionName.equals(BuiltinFunctionName.CONCAT.getName())) { return funcBuilder; } return castArguments(sourceTypes, diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index 7081179162a..c2475fa8144 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -50,7 +50,8 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { + if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) + && !isConcatFunction(unresolvedSignature)) { throw new ExpressionEvaluationException( String.format("%s function expected %s, but get %s", functionName, formatFunctions(functionBundle.keySet()), @@ -66,4 +67,9 @@ private String formatFunctions(Set functionSignatures) { return functionSignatures.stream().map(FunctionSignature::formatTypes) .collect(Collectors.joining(",", "{", "}")); } + + private boolean isConcatFunction(FunctionSignature signature) { + return signature.getFunctionName().equals(BuiltinFunctionName.CONCAT.getName()) + && !signature.getParamTypeList().isEmpty(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index 5301eb2dd5c..25eeef6fc83 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -60,39 +60,6 @@ public static DefaultFunctionResolver define(FunctionName functionName, List< return builder.build(); } - /** - * Define varargs function with implementation. - * - * @param functionName function name. - * @param functions a list of function implementation. - * @return VarargsFunctionResolver. - */ - public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, - SerializableFunction>... functions) { - return defineVarargsFunction(functionName, List.of(functions)); - } - - /** - * Define varargs function with implementation. - * - * @param functionName function name. - * @param functions a list of function implementation. - * @return VarargsFunctionResolver. - */ - public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, List< - SerializableFunction>> functions) { - - VarargsFunctionResolver.VarargsFunctionResolverBuilder builder = - VarargsFunctionResolver.builder(); - builder.functionName(functionName); - for (SerializableFunction> func - : functions) { - Pair functionBuilder = func.apply(functionName); - builder.functionBundle(functionBuilder.getKey(), functionBuilder.getValue()); - } - return builder.build(); - } /** * Implementation of no args function that uses FunctionProperties. diff --git a/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java deleted file mode 100644 index 5af054fb77e..00000000000 --- a/core/src/main/java/org/opensearch/sql/expression/function/VarargsFunctionResolver.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.function; - -import java.util.AbstractMap; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.stream.Collectors; -import lombok.Builder; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.Singular; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.exception.ExpressionEvaluationException; - -/** - * The Function Resolver hold the overload {@link FunctionBuilder} implementation. - * is composed by {@link FunctionName} which identified the function name - * and a map of {@link FunctionSignature} and {@link FunctionBuilder} - * to represent the overloaded implementation - */ -@Builder -@RequiredArgsConstructor -public class VarargsFunctionResolver implements FunctionResolver { - @Getter - private final FunctionName functionName; - @Singular("functionBundle") - private final Map functionBundle; - - /** - * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. - * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. - * If applying the widening rule, found the most match one, return it. - * If nothing found, throw {@link ExpressionEvaluationException} - * - * @return function signature and its builder - */ - @Override - public Pair resolve(FunctionSignature unresolvedSignature) { - PriorityQueue> functionMatchQueue = new PriorityQueue<>( - Map.Entry.comparingByKey()); - - for (FunctionSignature functionSignature : functionBundle.keySet()) { - functionMatchQueue.add( - new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), - functionSignature)); - } - Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (unresolvedSignature.getParamTypeList().isEmpty()) { - throw new ExpressionEvaluationException( - String.format("%s function expected %s, but get %s", functionName, - formatFunctions(functionBundle.keySet()), - unresolvedSignature.formatTypes() - )); - } else { - FunctionSignature resolvedSignature = bestMatchEntry.getValue(); - return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); - } - } - - private String formatFunctions(Set functionSignatures) { - return functionSignatures.stream().map(FunctionSignature::formatTypes) - .collect(Collectors.joining(",", "{", "}")); - } -} diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index e57c696785c..e1311a54302 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.function.FunctionDSL.define; -import static org.opensearch.sql.expression.function.FunctionDSL.defineVarargsFunction; import static org.opensearch.sql.expression.function.FunctionDSL.impl; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; @@ -25,7 +24,6 @@ import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; -import org.opensearch.sql.expression.function.VarargsFunctionResolver; /** * The definition of text functions. @@ -148,8 +146,8 @@ private DefaultFunctionResolver upper() { * Supports following signatures: * (STRING, STRING, ...., STRING) -> STRING */ - private VarargsFunctionResolver concat() { - return defineVarargsFunction(BuiltinFunctionName.CONCAT.getName(), + private DefaultFunctionResolver concat() { + return define(BuiltinFunctionName.CONCAT.getName(), impl(nullMissingHandling(strings -> new ExprStringValue(Arrays.stream(strings) .map(ExprValue::stringValue) diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index 28fbb7bd974..8bba3bd9b95 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -238,37 +238,6 @@ void resolve_unregistered() { assertEquals("unsupported function name: unknown", exception.getMessage()); } - @Test - void resolve_should_cast_arguments_for_varargs_function() { - FunctionSignature unresolvedSignature = new FunctionSignature( - mockFunctionName, ImmutableList.of(STRING, STRING, STRING)); - FunctionSignature resolvedSignature = new FunctionSignature( - mockFunctionName, Collections.emptyList()); - - VarargsFunctionResolver varargsFunctionResolver = mock(VarargsFunctionResolver.class); - FunctionBuilder funcBuilder = mock(FunctionBuilder.class); - - when(mockFunctionName.getFunctionName()).thenReturn("mockFunction"); - when(mockExpression.toString()).thenReturn("string"); - when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap); - when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); - when(mockMap.containsKey(eq(mockFunctionName))).thenReturn(true); - when(mockMap.get(eq(mockFunctionName))).thenReturn(varargsFunctionResolver); - when(varargsFunctionResolver.resolve(eq(unresolvedSignature))).thenReturn( - Pair.of(resolvedSignature, funcBuilder)); - repo.register(varargsFunctionResolver); - // Relax unnecessary stubbing check because error case test doesn't call this - lenient().doAnswer(invocation -> - new FakeFunctionExpression(mockFunctionName, invocation.getArgument(1)) - ).when(funcBuilder).apply(eq(functionProperties), any()); - - FunctionImplementation function = - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), unresolvedSignature) - .apply(functionProperties, - ImmutableList.of(mockExpression, mockExpression, mockExpression)); - assertEquals("mockFunction(string, string, string)", function.toString()); - } - private FunctionSignature registerFunctionResolver(FunctionName funcName, ExprType sourceType, ExprType targetType) { diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index baa299b60be..c413878c58d 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -9,8 +9,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -68,6 +71,7 @@ void resolve_function_not_match() { when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(notMatchFS.formatTypes()).thenReturn("[INTEGER,INTEGER]"); when(functionSignature.formatTypes()).thenReturn("[BOOLEAN,BOOLEAN]"); + when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(notMatchFS, notMatchBuilder)); @@ -76,4 +80,32 @@ void resolve_function_not_match() { assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]", exception.getMessage()); } + + @Test + void resolve_concat_function_signature_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); + when(functionSignature.getFunctionName()).thenReturn(functionName); + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING)); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(notMatchFS, notMatchBuilder)); + + assertEquals(notMatchBuilder, resolver.resolve(functionSignature).getValue()); + } + + @Test + void resolve_concat_function_signature_not_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); + when(functionSignature.getFunctionName()).thenReturn(functionName); + // Concat function with no arguments + when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList()); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(notMatchFS, notMatchBuilder)); + + assertThrows(ExpressionEvaluationException.class, + () -> resolver.resolve(functionSignature)); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java deleted file mode 100644 index dafec21f402..00000000000 --- a/core/src/test/java/org/opensearch/sql/expression/function/VarargsFunctionResolverTest.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.expression.function; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import java.util.Collections; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.data.type.WideningTypeRule; -import org.opensearch.sql.exception.ExpressionEvaluationException; - -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -@ExtendWith(MockitoExtension.class) -class VarargsFunctionResolverTest { - @Mock - private FunctionSignature exactlyMatchFS; - @Mock - private FunctionSignature bestMatchFS; - @Mock - private FunctionSignature leastMatchFS; - @Mock - private FunctionSignature notMatchFS; - @Mock - private FunctionSignature functionSignature; - @Mock - private FunctionBuilder exactlyMatchBuilder; - @Mock - private FunctionBuilder bestMatchBuilder; - @Mock - private FunctionBuilder leastMatchBuilder; - @Mock - private FunctionBuilder notMatchBuilder; - - private FunctionName functionName = FunctionName.of("test_function"); - - @Test - void resolve_function_signature_exactly_match() { - when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); - when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING, STRING, STRING)); - VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, - ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); - - assertEquals(exactlyMatchBuilder, resolver.resolve(functionSignature).getValue()); - } - - @Test - void resolve_function_signature_best_match() { - when(functionSignature.match(bestMatchFS)).thenReturn(1); - when(functionSignature.match(leastMatchFS)).thenReturn(2); - when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING, STRING, STRING)); - VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, - ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); - - assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); - } - - @Test - void resolve_function_not_match() { - when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); - // accepts 1 or more arguments - when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList()); - VarargsFunctionResolver resolver = new VarargsFunctionResolver(functionName, - ImmutableMap.of(notMatchFS, notMatchBuilder)); - - assertThrows(ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); - } -} From 201f90af4416f5bdd0be7482cc24954c78226d71 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 17 Jan 2023 12:34:59 -0800 Subject: [PATCH 5/9] Function expects 1-9 arguments Signed-off-by: Margarit Hakobyan --- .../function/DefaultFunctionResolver.java | 10 +++++-- .../function/DefaultFunctionResolverTest.java | 27 +++++++++++++++++-- docs/user/dql/functions.rst | 2 +- docs/user/ppl/functions/string.rst | 2 +- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index c2475fa8144..7201d4ba84d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -50,6 +50,13 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); + if (isConcatFunction(unresolvedSignature) + && (unresolvedSignature.getParamTypeList().isEmpty() + || unresolvedSignature.getParamTypeList().size() > 9)) { + throw new ExpressionEvaluationException( + String.format("%s function expected 1-9 arguments, but got %s", + functionName, unresolvedSignature.getParamTypeList().size())); + } if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) && !isConcatFunction(unresolvedSignature)) { throw new ExpressionEvaluationException( @@ -69,7 +76,6 @@ private String formatFunctions(Set functionSignatures) { } private boolean isConcatFunction(FunctionSignature signature) { - return signature.getFunctionName().equals(BuiltinFunctionName.CONCAT.getName()) - && !signature.getParamTypeList().isEmpty(); + return signature.getFunctionName().equals(BuiltinFunctionName.CONCAT.getName()); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index c413878c58d..579a2499e56 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -50,6 +50,7 @@ class DefaultFunctionResolverTest { @Test void resolve_function_signature_exactly_match() { when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); @@ -60,6 +61,7 @@ void resolve_function_signature_exactly_match() { void resolve_function_signature_best_match() { when(functionSignature.match(bestMatchFS)).thenReturn(1); when(functionSignature.match(leastMatchFS)).thenReturn(2); + when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); @@ -95,7 +97,7 @@ void resolve_concat_function_signature_match() { } @Test - void resolve_concat_function_signature_not_match() { + void resolve_concat_no_args_function_signature_not_match() { functionName = FunctionName.of("concat"); when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(functionSignature.getFunctionName()).thenReturn(functionName); @@ -105,7 +107,28 @@ void resolve_concat_function_signature_not_match() { DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(notMatchFS, notMatchBuilder)); - assertThrows(ExpressionEvaluationException.class, + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> resolver.resolve(functionSignature)); + assertEquals("concat function expected 1-9 arguments, but got 0", + exception.getMessage()); + } + + @Test + void resolve_concat_too_many_args_function_signature_not_match() { + functionName = FunctionName.of("concat"); + when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); + when(functionSignature.getFunctionName()).thenReturn(functionName); + // Concat function with more than 9 arguments + when(functionSignature.getParamTypeList()).thenReturn(ImmutableList + .of(STRING, STRING, STRING, STRING, STRING, + STRING, STRING, STRING, STRING, STRING)); + + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, + ImmutableMap.of(notMatchFS, notMatchBuilder)); + + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); + assertEquals("concat function expected 1-9 arguments, but got 10", + exception.getMessage()); } } diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 69e20b07259..23547c02164 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2614,7 +2614,7 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. If any of the expressions is a NULL value, it returns NULL. +Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. Expects 1-9 arguments. If any of the expressions is a NULL value, it returns NULL. Argument type: STRING, STRING, ...., STRING diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index 315a46616ca..c55dd654a5a 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -14,7 +14,7 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. +Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. Expects 1-9 arguments. Argument type: STRING, STRING, ...., STRING From 2eee626584f5952178e8714e501e55d139011a11 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 24 Jan 2023 10:55:24 -0800 Subject: [PATCH 6/9] Add varargs functions map for better extensibility Signed-off-by: Margarit Hakobyan --- .../sql/expression/function/BuiltinFunctionName.java | 5 +++++ .../expression/function/BuiltinFunctionRepository.java | 2 +- .../sql/expression/function/DefaultFunctionResolver.java | 9 +++++---- .../expression/function/DefaultFunctionResolverTest.java | 6 +++--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f9d38a0da39..92f86bb03b3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -258,6 +258,11 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } + public static final Map VARARGS_FUNCTIONS_MAP = + new ImmutableMap.Builder() + .put("concat", BuiltinFunctionName.CONCAT.name) + .build(); + private static final Map AGGREGATION_FUNC_MAPPING = new ImmutableMap.Builder() .put("max", BuiltinFunctionName.MAX) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 8ddcfea5580..adfbbc9a7de 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -182,7 +182,7 @@ private FunctionBuilder getFunctionBuilder( if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { return funcBuilder; } - if (functionName.equals(BuiltinFunctionName.CONCAT.getName())) { + if (BuiltinFunctionName.VARARGS_FUNCTIONS_MAP.containsValue(functionName)) { return funcBuilder; } return castArguments(sourceTypes, diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index 7201d4ba84d..4ceb3ad4c73 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -50,7 +50,7 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (isConcatFunction(unresolvedSignature) + if (isVarargsFunction(unresolvedSignature) && (unresolvedSignature.getParamTypeList().isEmpty() || unresolvedSignature.getParamTypeList().size() > 9)) { throw new ExpressionEvaluationException( @@ -58,7 +58,7 @@ public Pair resolve(FunctionSignature unreso functionName, unresolvedSignature.getParamTypeList().size())); } if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) - && !isConcatFunction(unresolvedSignature)) { + && !isVarargsFunction(unresolvedSignature)) { throw new ExpressionEvaluationException( String.format("%s function expected %s, but get %s", functionName, formatFunctions(functionBundle.keySet()), @@ -75,7 +75,8 @@ private String formatFunctions(Set functionSignatures) { .collect(Collectors.joining(",", "{", "}")); } - private boolean isConcatFunction(FunctionSignature signature) { - return signature.getFunctionName().equals(BuiltinFunctionName.CONCAT.getName()); + private boolean isVarargsFunction(FunctionSignature signature) { + return BuiltinFunctionName.VARARGS_FUNCTIONS_MAP + .containsValue(signature.getFunctionName()); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index 579a2499e56..f13a30c09d2 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -84,7 +84,7 @@ void resolve_function_not_match() { } @Test - void resolve_concat_function_signature_match() { + void resolve_varargs_function_signature_match() { functionName = FunctionName.of("concat"); when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(functionSignature.getFunctionName()).thenReturn(functionName); @@ -97,7 +97,7 @@ void resolve_concat_function_signature_match() { } @Test - void resolve_concat_no_args_function_signature_not_match() { + void resolve_varargs_no_args_function_signature_not_match() { functionName = FunctionName.of("concat"); when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(functionSignature.getFunctionName()).thenReturn(functionName); @@ -114,7 +114,7 @@ void resolve_concat_no_args_function_signature_not_match() { } @Test - void resolve_concat_too_many_args_function_signature_not_match() { + void resolve_varargs_too_many_args_function_signature_not_match() { functionName = FunctionName.of("concat"); when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(functionSignature.getFunctionName()).thenReturn(functionName); From d066caa0fa04d5e7cc345ba42ed497e007f1d541 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 24 Jan 2023 16:03:48 -0800 Subject: [PATCH 7/9] Address PR review feedback Signed-off-by: Margarit Hakobyan --- .../function/BuiltinFunctionName.java | 4 -- .../function/BuiltinFunctionRepository.java | 12 ++-- .../function/DefaultFunctionResolver.java | 15 ++-- .../sql/expression/function/FunctionDSL.java | 68 +------------------ .../function/FunctionSignature.java | 6 ++ .../function/SerializableVarargsFunction.java | 22 ------ .../sql/expression/text/TextFunction.java | 42 ++++++++++-- .../function/DefaultFunctionResolverTest.java | 24 +++---- .../function/FunctionDSLTestBase.java | 2 - .../function/FunctionDSLimplVarargsTest.java | 32 --------- 10 files changed, 71 insertions(+), 156 deletions(-) delete mode 100644 core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java delete mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 92f86bb03b3..e08e42800be 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -258,10 +258,6 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } - public static final Map VARARGS_FUNCTIONS_MAP = - new ImmutableMap.Builder() - .put("concat", BuiltinFunctionName.CONCAT.name) - .build(); private static final Map AGGREGATION_FUNC_MAPPING = new ImmutableMap.Builder() diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index adfbbc9a7de..085d86babf7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.ast.expression.Cast.getCastFunctionName; import static org.opensearch.sql.ast.expression.Cast.isCastFunction; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; @@ -179,10 +180,9 @@ private FunctionBuilder getFunctionBuilder( List sourceTypes = functionSignature.getParamTypeList(); List targetTypes = resolvedSignature.getKey().getParamTypeList(); FunctionBuilder funcBuilder = resolvedSignature.getValue(); - if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { - return funcBuilder; - } - if (BuiltinFunctionName.VARARGS_FUNCTIONS_MAP.containsValue(functionName)) { + if (isCastFunction(functionName) + || isVarArgFunction(targetTypes) + || sourceTypes.equals(targetTypes)) { return funcBuilder; } return castArguments(sourceTypes, @@ -233,4 +233,8 @@ private Function cast(Expression arg, ExprType t return functionProperties -> (Expression) compile(functionProperties, castFunctionName, List.of(arg)); } + + private boolean isVarArgFunction(List argTypes) { + return argTypes.size() == 1 && argTypes.get(0) == ARRAY; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index 4ceb3ad4c73..87b7a555773 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -5,7 +5,10 @@ package org.opensearch.sql.expression.function; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; + import java.util.AbstractMap; +import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.Set; @@ -15,6 +18,7 @@ import lombok.RequiredArgsConstructor; import lombok.Singular; import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.ExpressionEvaluationException; /** @@ -50,15 +54,15 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (isVarargsFunction(unresolvedSignature) + if (isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) && (unresolvedSignature.getParamTypeList().isEmpty() - || unresolvedSignature.getParamTypeList().size() > 9)) { + || unresolvedSignature.getParamTypeList().size() > 9)) { throw new ExpressionEvaluationException( String.format("%s function expected 1-9 arguments, but got %s", functionName, unresolvedSignature.getParamTypeList().size())); } if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) - && !isVarargsFunction(unresolvedSignature)) { + && !isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) { throw new ExpressionEvaluationException( String.format("%s function expected %s, but get %s", functionName, formatFunctions(functionBundle.keySet()), @@ -75,8 +79,7 @@ private String formatFunctions(Set functionSignatures) { .collect(Collectors.joining(",", "{", "}")); } - private boolean isVarargsFunction(FunctionSignature signature) { - return BuiltinFunctionName.VARARGS_FUNCTIONS_MAP - .containsValue(signature.getFunctionName()); + private boolean isVarArgFunction(List argTypes) { + return argTypes.size() == 1 && argTypes.get(0) == ARRAY; } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index 25eeef6fc83..f6666309b57 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -214,56 +214,6 @@ public static SerializableFunction function.apply(arg), returnType, argsType); } - /** - * Varargs Function Implementation. - * This implementation considers 1...n args of the same type. - * - * @param function {@link ExprValue} based varargs function. - * @param returnType return type. - * @param argsType argument type. - * @return Varargs Function Implementation. - */ - public static SerializableFunction> impl( - SerializableVarargsFunction function, - ExprType returnType, - ExprType argsType, - boolean withVarargs) { - - return functionName -> { - AtomicInteger argsCount = new AtomicInteger(0); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - argsCount.set(arguments.size()); - ExprValue[] args = arguments.stream() - .map(arg -> arg.valueOf(valueEnv)) - .collect(Collectors.toList()) - .toArray(new ExprValue[arguments.size()]); - - return function.apply(args); - } - - @Override - public ExprType type() { - return returnType; - } - - @Override - public String toString() { - return String.format("%s(%s)", functionName, arguments.stream() - .map(Object::toString) - .collect(Collectors.joining(", "))); - } - }; - ExprCoreType[] argsTypes = new ExprCoreType[argsCount.get()]; - Arrays.fill(argsTypes, argsType); - FunctionSignature functionSignature = - new FunctionSignature(functionName, List.of(argsTypes)); - return Pair.of(functionSignature, functionBuilder); - }; - } - /** * Binary Function Implementation. * @@ -375,29 +325,13 @@ public SerializableTriFunction nullM }; } - /** - * Wrapper the varargs ExprValue function with default NULL and MISSING handling. - */ - public SerializableVarargsFunction nullMissingHandling( - SerializableVarargsFunction function, boolean withVarargs) { - return (args) -> { - if (Arrays.stream(args).anyMatch(ExprValue::isMissing)) { - return ExprValueUtils.missingValue(); - } - if (Arrays.stream(args).anyMatch(ExprValue::isNull)) { - return ExprValueUtils.nullValue(); - } - return function.apply(args); - }; - } - /** * Wrapper the unary ExprValue function that is aware of FunctionProperties, * with default NULL and MISSING handling. */ public static SerializableBiFunction nullMissingHandlingWithProperties( - SerializableBiFunction implementation) { + SerializableBiFunction implementation) { return (functionProperties, v1) -> { if (v1.isMissing()) { return ExprValueUtils.missingValue(); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java index adb16983866..7866c9112a8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java @@ -5,6 +5,8 @@ package org.opensearch.sql.expression.function; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; + import java.util.List; import java.util.stream.Collectors; import lombok.EqualsAndHashCode; @@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) { || paramTypeList.size() != functionTypeList.size()) { return NOT_MATCH; } + // TODO: improve to support regular and array type mixed, ex. func(int,string,array) + if (functionTypeList.size() == 1 && functionTypeList.get(0) == ARRAY) { + return EXACTLY_MATCH; + } int matchDegree = EXACTLY_MATCH; for (int i = 0; i < paramTypeList.size(); i++) { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java deleted file mode 100644 index 3c0c07fa79f..00000000000 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableVarargsFunction.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.expression.function; - -import java.io.Serializable; - -/** - * Serializable Varargs Function. - */ -public interface SerializableVarargsFunction extends Serializable { - /** - * Applies this function to the given arguments. - * - * @param t the function argument - * @return the function result - */ - R apply(T... t); -} diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index e1311a54302..5b670affd49 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -6,22 +6,31 @@ package org.opensearch.sql.expression.text; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.function.FunctionDSL.define; import static org.opensearch.sql.expression.function.FunctionDSL.impl; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; -import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; @@ -147,11 +156,32 @@ private DefaultFunctionResolver upper() { * (STRING, STRING, ...., STRING) -> STRING */ private DefaultFunctionResolver concat() { - return define(BuiltinFunctionName.CONCAT.getName(), - impl(nullMissingHandling(strings -> - new ExprStringValue(Arrays.stream(strings) - .map(ExprValue::stringValue) - .collect(Collectors.joining())), true), STRING, STRING, true)); + FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName(); + return define(concatFuncName, funcName -> + Pair.of( + new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)), + (funcProp, args) -> new FunctionExpression(funcName, args) { + @Override + public ExprValue valueOf(Environment valueEnv) { + List exprValues = args.stream() + .map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList()); + if (exprValues.stream().anyMatch(ExprValue::isMissing)) { + return ExprValueUtils.missingValue(); + } + if (exprValues.stream().anyMatch(ExprValue::isNull)) { + return ExprValueUtils.nullValue(); + } + return new ExprStringValue(exprValues.stream() + .map(argVal -> String.valueOf(argVal.value())) + .collect(Collectors.joining())); + } + + @Override + public ExprType type() { + return STRING; + } + } + )); } /** diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index f13a30c09d2..202c1bd0aa9 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import com.google.common.collect.ImmutableList; @@ -50,7 +51,6 @@ class DefaultFunctionResolverTest { @Test void resolve_function_signature_exactly_match() { when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); - when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); @@ -61,7 +61,6 @@ void resolve_function_signature_exactly_match() { void resolve_function_signature_best_match() { when(functionSignature.match(bestMatchFS)).thenReturn(1); when(functionSignature.match(leastMatchFS)).thenReturn(2); - when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); @@ -73,7 +72,6 @@ void resolve_function_not_match() { when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(notMatchFS.formatTypes()).thenReturn("[INTEGER,INTEGER]"); when(functionSignature.formatTypes()).thenReturn("[BOOLEAN,BOOLEAN]"); - when(functionSignature.getFunctionName()).thenReturn(functionName); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(notMatchFS, notMatchBuilder)); @@ -86,26 +84,26 @@ void resolve_function_not_match() { @Test void resolve_varargs_function_signature_match() { functionName = FunctionName.of("concat"); - when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); - when(functionSignature.getFunctionName()).thenReturn(functionName); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING)); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, - ImmutableMap.of(notMatchFS, notMatchBuilder)); + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); - assertEquals(notMatchBuilder, resolver.resolve(functionSignature).getValue()); + assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); } @Test void resolve_varargs_no_args_function_signature_not_match() { functionName = FunctionName.of("concat"); - when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); - when(functionSignature.getFunctionName()).thenReturn(functionName); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); // Concat function with no arguments when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList()); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, - ImmutableMap.of(notMatchFS, notMatchBuilder)); + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); @@ -116,15 +114,15 @@ void resolve_varargs_no_args_function_signature_not_match() { @Test void resolve_varargs_too_many_args_function_signature_not_match() { functionName = FunctionName.of("concat"); - when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); - when(functionSignature.getFunctionName()).thenReturn(functionName); + when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); + when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY)); // Concat function with more than 9 arguments when(functionSignature.getParamTypeList()).thenReturn(ImmutableList .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING)); DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, - ImmutableMap.of(notMatchFS, notMatchBuilder)); + ImmutableMap.of(bestMatchFS, bestMatchBuilder)); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java index 5bcc9f9e89f..63c6ea33296 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java @@ -59,8 +59,6 @@ public int compareTo(ExprValue o) { twoArgs = (v1, v2) -> ANY; static final SerializableTriFunction threeArgs = (v1, v2, v3) -> ANY; - static final SerializableVarargsFunction - varrgs = (v1) -> ANY; @Mock FunctionProperties mockProperties; } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java deleted file mode 100644 index cee8889359b..00000000000 --- a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplVarargsTest.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.function; - -import static org.opensearch.sql.expression.function.FunctionDSL.impl; - -import java.util.List; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; - -class FunctionDSLimplVarargsTest extends FunctionDSLimplTestBase { - - @Override - SerializableFunction> - getImplementationGenerator() { - return impl(varrgs, ANY_TYPE, ANY_TYPE, true); - } - - @Override - List getSampleArguments() { - return List.of(DSL.literal(ANY)); - } - - @Override - String getExpected_toString() { - return "sample(ANY)"; - } -} From 4e2422e5795ee7f3ebf6adef78adfbcb46ead165 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 24 Jan 2023 16:05:50 -0800 Subject: [PATCH 8/9] Minor cleanup Signed-off-by: Margarit Hakobyan --- .../opensearch/sql/expression/function/BuiltinFunctionName.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index e08e42800be..f9d38a0da39 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -258,7 +258,6 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } - private static final Map AGGREGATION_FUNC_MAPPING = new ImmutableMap.Builder() .put("max", BuiltinFunctionName.MAX) From 0654fbb534c47b79b65856bbd58929814de63733 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Tue, 24 Jan 2023 21:34:15 -0800 Subject: [PATCH 9/9] Address more PR review feedback Signed-off-by: Margarit Hakobyan --- .../function/BuiltinFunctionRepository.java | 7 +------ .../function/DefaultFunctionResolver.java | 14 +++----------- .../sql/expression/function/FunctionDSL.java | 2 -- .../sql/expression/function/FunctionSignature.java | 9 ++++++++- .../sql/expression/text/TextFunction.java | 4 ++-- docs/user/dql/functions.rst | 2 +- docs/user/ppl/functions/string.rst | 2 +- 7 files changed, 16 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 085d86babf7..20f56a21cb7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.ast.expression.Cast.getCastFunctionName; import static org.opensearch.sql.ast.expression.Cast.isCastFunction; -import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; @@ -181,7 +180,7 @@ private FunctionBuilder getFunctionBuilder( List targetTypes = resolvedSignature.getKey().getParamTypeList(); FunctionBuilder funcBuilder = resolvedSignature.getValue(); if (isCastFunction(functionName) - || isVarArgFunction(targetTypes) + || FunctionSignature.isVarArgFunction(targetTypes) || sourceTypes.equals(targetTypes)) { return funcBuilder; } @@ -233,8 +232,4 @@ private Function cast(Expression arg, ExprType t return functionProperties -> (Expression) compile(functionProperties, castFunctionName, List.of(arg)); } - - private boolean isVarArgFunction(List argTypes) { - return argTypes.size() == 1 && argTypes.get(0) == ARRAY; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index 87b7a555773..a28fa7e0ada 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -5,10 +5,7 @@ package org.opensearch.sql.expression.function; -import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; - import java.util.AbstractMap; -import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.Set; @@ -18,7 +15,6 @@ import lombok.RequiredArgsConstructor; import lombok.Singular; import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.ExpressionEvaluationException; /** @@ -54,15 +50,15 @@ public Pair resolve(FunctionSignature unreso functionSignature)); } Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) + if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) && (unresolvedSignature.getParamTypeList().isEmpty() || unresolvedSignature.getParamTypeList().size() > 9)) { throw new ExpressionEvaluationException( - String.format("%s function expected 1-9 arguments, but got %s", + String.format("%s function expected 1-9 arguments, but got %d", functionName, unresolvedSignature.getParamTypeList().size())); } if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) - && !isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) { + && !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) { throw new ExpressionEvaluationException( String.format("%s function expected %s, but get %s", functionName, formatFunctions(functionBundle.keySet()), @@ -78,8 +74,4 @@ private String formatFunctions(Set functionSignatures) { return functionSignatures.stream().map(FunctionSignature::formatTypes) .collect(Collectors.joining(",", "{", "}")); } - - private boolean isVarArgFunction(List argTypes) { - return argTypes.size() == 1 && argTypes.get(0) == ARRAY; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index f6666309b57..d94d7cdf601 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -9,14 +9,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java index 7866c9112a8..0c59d71c256 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java @@ -42,7 +42,7 @@ public int match(FunctionSignature functionSignature) { return NOT_MATCH; } // TODO: improve to support regular and array type mixed, ex. func(int,string,array) - if (functionTypeList.size() == 1 && functionTypeList.get(0) == ARRAY) { + if (isVarArgFunction(functionTypeList)) { return EXACTLY_MATCH; } @@ -68,4 +68,11 @@ public String formatTypes() { .map(ExprType::typeName) .collect(Collectors.joining(",", "[", "]")); } + + /** + * util function - returns true if function has variable arguments. + */ + protected static boolean isVarArgFunction(List argTypes) { + return argTypes.size() == 1 && argTypes.get(0) == ARRAY; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 5b670affd49..e56c85a0c8e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -172,8 +172,8 @@ public ExprValue valueOf(Environment valueEnv) { return ExprValueUtils.nullValue(); } return new ExprStringValue(exprValues.stream() - .map(argVal -> String.valueOf(argVal.value())) - .collect(Collectors.joining())); + .map(ExprValue::stringValue) + .collect(Collectors.joining())); } @Override diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 23547c02164..ab96075ac3e 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2614,7 +2614,7 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. Expects 1-9 arguments. If any of the expressions is a NULL value, it returns NULL. +Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL. Argument type: STRING, STRING, ...., STRING diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index c55dd654a5a..9b7e69d9850 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -14,7 +14,7 @@ CONCAT Description >>>>>>>>>>> -Usage: CONCAT(str1, str2, ...., str_n) adds two or more strings together. Expects 1-9 arguments. +Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. Argument type: STRING, STRING, ...., STRING