Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ private FunctionBuilder getFunctionBuilder(
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
// For functions with variable number of args (ex: concat())
// targetTypes will always be empty (as the function signature is not fixed),
// and failure will occur.
// So, in this case sourceTypes are passed instead of targetTypes to address that.
if (functionResolverMap.get(functionName) instanceof VarargsFunctionResolver) {
return castArguments(sourceTypes, sourceTypes, funcBuilder);
}
return castArguments(sourceTypes,
targetTypes, funcBuilder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
Expand Down Expand Up @@ -58,6 +60,39 @@ public static DefaultFunctionResolver define(FunctionName functionName, List<
return builder.build();
}

/**
* Define varargs function with implementation.
*
* @param functionName function name.
* @param functions a list of function implementation.
* @return VarargsFunctionResolver.
*/
public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName,
SerializableFunction<FunctionName,
Pair<FunctionSignature, FunctionBuilder>>... functions) {
return defineVarargsFunction(functionName, List.of(functions));
}

/**
* Define varargs function with implementation.
*
* @param functionName function name.
* @param functions a list of function implementation.
* @return VarargsFunctionResolver.
*/
public static VarargsFunctionResolver defineVarargsFunction(FunctionName functionName, List<
SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>>> functions) {

VarargsFunctionResolver.VarargsFunctionResolverBuilder builder =
VarargsFunctionResolver.builder();
builder.functionName(functionName);
for (SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>> func
: functions) {
Pair<FunctionSignature, FunctionBuilder> functionBuilder = func.apply(functionName);
builder.functionBundle(functionBuilder.getKey(), functionBuilder.getValue());
}
return builder.build();
}

/**
* Implementation of no args function that uses FunctionProperties.
Expand Down Expand Up @@ -212,6 +247,56 @@ public static SerializableFunction<FunctionName, Pair<FunctionSignature, Functio
return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType);
}

/**
* Varargs Function Implementation.
* This implementation considers 1...n args of the same type.
*
* @param function {@link ExprValue} based varargs function.
* @param returnType return type.
* @param argsType argument type.
* @return Varargs Function Implementation.
*/
public static SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>> impl(
SerializableVarargsFunction<ExprValue, ExprValue> function,
ExprType returnType,
ExprType argsType,
boolean withVarargs) {

return functionName -> {
AtomicInteger argsCount = new AtomicInteger(0);
FunctionBuilder functionBuilder =
(functionProperties, arguments) -> new FunctionExpression(functionName, arguments) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
argsCount.set(arguments.size());
ExprValue[] args = arguments.stream()
.map(arg -> arg.valueOf(valueEnv))
.collect(Collectors.toList())
.toArray(new ExprValue[arguments.size()]);

return function.apply(args);
}

@Override
public ExprType type() {
return returnType;
}

@Override
public String toString() {
return String.format("%s(%s)", functionName, arguments.stream()
.map(Object::toString)
.collect(Collectors.joining(", ")));
}
};
ExprCoreType[] argsTypes = new ExprCoreType[argsCount.get()];
Arrays.fill(argsTypes, argsType);
FunctionSignature functionSignature =
new FunctionSignature(functionName, List.of(argsTypes));
return Pair.of(functionSignature, functionBuilder);
};
}

/**
* Binary Function Implementation.
*
Expand Down Expand Up @@ -323,13 +408,29 @@ public SerializableTriFunction<ExprValue, ExprValue, ExprValue, ExprValue> nullM
};
}

/**
* Wrapper the varargs ExprValue function with default NULL and MISSING handling.
*/
public SerializableVarargsFunction<ExprValue, ExprValue> nullMissingHandling(
SerializableVarargsFunction<ExprValue, ExprValue> function, boolean withVarargs) {
return (args) -> {
if (Arrays.stream(args).anyMatch(ExprValue::isMissing)) {
return ExprValueUtils.missingValue();
} else if (Arrays.stream(args).anyMatch(ExprValue::isNull)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else not needed here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in bed81bf

return ExprValueUtils.nullValue();
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else not needed here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in bed81bf

return function.apply(args);
}
};
}

/**
* Wrapper the unary ExprValue function that is aware of FunctionProperties,
* with default NULL and MISSING handling.
*/
public static SerializableBiFunction<FunctionProperties, ExprValue, ExprValue>
nullMissingHandlingWithProperties(
SerializableBiFunction<FunctionProperties, ExprValue, ExprValue> implementation) {
SerializableBiFunction<FunctionProperties, ExprValue, ExprValue> implementation) {
return (functionProperties, v1) -> {
if (v1.isMissing()) {
return ExprValueUtils.missingValue();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.expression.function;

import java.io.Serializable;

/**
* Serializable Varargs Function.
*/
public interface SerializableVarargsFunction<T, R> extends Serializable {
/**
* Applies this function to the given arguments.
*
* @param t the function argument
* @return the function result
*/
R apply(T... t);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import java.util.AbstractMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Singular;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.exception.ExpressionEvaluationException;

/**
* The Function Resolver hold the overload {@link FunctionBuilder} implementation.
* is composed by {@link FunctionName} which identified the function name
* and a map of {@link FunctionSignature} and {@link FunctionBuilder}
* to represent the overloaded implementation
*/
@Builder
@RequiredArgsConstructor
public class VarargsFunctionResolver implements FunctionResolver {
@Getter
private final FunctionName functionName;
@Singular("functionBundle")
private final Map<FunctionSignature, FunctionBuilder> functionBundle;

/**
* Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}.
* If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it.
* If applying the widening rule, found the most match one, return it.
* If nothing found, throw {@link ExpressionEvaluationException}
*
* @return function signature and its builder
*/
@Override
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
PriorityQueue<Map.Entry<Integer, FunctionSignature>> functionMatchQueue = new PriorityQueue<>(
Map.Entry.comparingByKey());

for (FunctionSignature functionSignature : functionBundle.keySet()) {
functionMatchQueue.add(
new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature),
functionSignature));
}
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
if (unresolvedSignature.getParamTypeList().isEmpty()) {
throw new ExpressionEvaluationException(
String.format("%s function expected %s, but get %s", functionName,
formatFunctions(functionBundle.keySet()),
unresolvedSignature.formatTypes()
));
} else {
FunctionSignature resolvedSignature = bestMatchEntry.getValue();
return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature));
}
}

private String formatFunctions(Set<FunctionSignature> functionSignatures) {
return functionSignatures.stream().map(FunctionSignature::formatTypes)
.collect(Collectors.joining(",", "{", "}"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.define;
import static org.opensearch.sql.expression.function.FunctionDSL.defineVarargsFunction;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import java.util.Arrays;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprStringValue;
Expand All @@ -22,7 +25,7 @@
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.expression.function.SerializableTriFunction;

import org.opensearch.sql.expression.function.VarargsFunctionResolver;

/**
* The definition of text functions.
Expand Down Expand Up @@ -141,16 +144,16 @@ private DefaultFunctionResolver upper() {
}

/**
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
* Extend to accept variable argument amounts.
* Concatenates a list of Strings.
* Supports following signatures:
* (STRING, STRING) -> STRING
* (STRING, STRING, ...., STRING) -> STRING
*/
private DefaultFunctionResolver concat() {
return define(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling((str1, str2) ->
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
private VarargsFunctionResolver concat() {
return defineVarargsFunction(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling(strings ->
new ExprStringValue(Arrays.stream(strings)
.map(ExprValue::stringValue)
.collect(Collectors.joining())), true), STRING, STRING, true));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,37 @@ void resolve_unregistered() {
assertEquals("unsupported function name: unknown", exception.getMessage());
}

@Test
void resolve_should_cast_arguments_for_varargs_function() {
FunctionSignature unresolvedSignature = new FunctionSignature(
mockFunctionName, ImmutableList.of(STRING, STRING, STRING));
FunctionSignature resolvedSignature = new FunctionSignature(
mockFunctionName, Collections.emptyList());

VarargsFunctionResolver varargsFunctionResolver = mock(VarargsFunctionResolver.class);
FunctionBuilder funcBuilder = mock(FunctionBuilder.class);

when(mockFunctionName.getFunctionName()).thenReturn("mockFunction");
when(mockExpression.toString()).thenReturn("string");
when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap);
when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true);
when(mockMap.containsKey(eq(mockFunctionName))).thenReturn(true);
when(mockMap.get(eq(mockFunctionName))).thenReturn(varargsFunctionResolver);
when(varargsFunctionResolver.resolve(eq(unresolvedSignature))).thenReturn(
Pair.of(resolvedSignature, funcBuilder));
repo.register(varargsFunctionResolver);
// Relax unnecessary stubbing check because error case test doesn't call this
lenient().doAnswer(invocation ->
new FakeFunctionExpression(mockFunctionName, invocation.getArgument(1))
).when(funcBuilder).apply(eq(functionProperties), any());

FunctionImplementation function =
repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), unresolvedSignature)
.apply(functionProperties,
ImmutableList.of(mockExpression, mockExpression, mockExpression));
assertEquals("mockFunction(string, string, string)", function.toString());
}

private FunctionSignature registerFunctionResolver(FunctionName funcName,
ExprType sourceType,
ExprType targetType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public int compareTo(ExprValue o) {
twoArgs = (v1, v2) -> ANY;
static final SerializableTriFunction<ExprValue, ExprValue, ExprValue, ExprValue>
threeArgs = (v1, v2, v3) -> ANY;
static final SerializableVarargsFunction<ExprValue, ExprValue>
varrgs = (v1) -> ANY;
@Mock
FunctionProperties mockProperties;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.expression.function.FunctionDSL.impl;

import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;

class FunctionDSLimplVarargsTest extends FunctionDSLimplTestBase {

@Override
SerializableFunction<FunctionName, Pair<FunctionSignature, FunctionBuilder>>
getImplementationGenerator() {
return impl(varrgs, ANY_TYPE, ANY_TYPE, true);
}

@Override
List<Expression> getSampleArguments() {
return List.of(DSL.literal(ANY));
}

@Override
String getExpected_toString() {
return "sample(ANY)";
}
}
Loading