Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,26 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
(arguments.isEmpty() || arguments.size() == 1)
? Collections.emptyList()
: arguments.subList(1, arguments.size());
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args);
return PlanUtils.makeOver(
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
List<RexNode> nodes =
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
functionName, field, args, context.rexBuilder);
return nodes != null
? PlanUtils.makeOver(
context,
functionName,
nodes.getFirst(),
args.subList(0, nodes.size()),
partitions,
List.of(),
node.getWindowFrame())
: PlanUtils.makeOver(
context,
functionName,
field,
args,
partitions,
List.of(),
node.getWindowFrame());
})
.orElseThrow(
() ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,27 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.data.type.WideningTypeRule;
import org.opensearch.sql.exception.ExpressionEvaluationException;

public class CoercionUtils {

public final class CoercionUtils {
/**
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
* of parameter types matches the arguments or if casting fails.
Expand All @@ -31,15 +39,26 @@ public class CoercionUtils {
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();

// TODO: var args?

List<ExprType> sourceTypes =
arguments.stream()
.map(node -> OpenSearchTypeFactory.convertRelDataTypeToExprType(node.getType()))
.collect(Collectors.toList());
// Candidate parameter signatures ordered by decreasing widening distance
PriorityQueue<Pair<List<ExprType>, Integer>> rankedSignatures =
new PriorityQueue<>((left, right) -> Integer.compare(right.getValue(), left.getValue()));
for (List<ExprType> paramTypes : paramTypeCombinations) {
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
if (castedArguments != null) {
return castedArguments;
int distance = distance(sourceTypes, paramTypes);
if (distance == TYPE_EQUAL) {
return castArguments(builder, paramTypes, arguments);
}
Optional.of(distance)
.filter(value -> value != IMPOSSIBLE_WIDENING)
.ifPresent(value -> rankedSignatures.add(Pair.of(paramTypes, value)));
}
return null;
return Optional.ofNullable(rankedSignatures.peek())
.map(Pair::getKey)
.map(paramTypes -> castArguments(builder, paramTypes, arguments))
.orElse(null);
}

/**
Expand Down Expand Up @@ -90,9 +109,12 @@ public class CoercionUtils {
if (!argType.shouldCast(targetType)) {
return arg;
}

if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
if (distance(argType, targetType) != IMPOSSIBLE_WIDENING) {
return builder.makeCast(
OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true);
} else if (argType == ExprCoreType.STRING && NUMBER_TYPES.contains(targetType)) {
return builder.makeCast(
OpenSearchTypeFactory.convertExprTypeToRelDataType(ExprCoreType.DOUBLE), arg, true, true);
}
return null;
}
Expand All @@ -118,12 +140,8 @@ public class CoercionUtils {
for (int i = 1; i < arguments.size(); i++) {
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
try {
if (areDateAndTime(widestType, type)) {
// If one is date and the other is time, we consider timestamp as the widest type
widestType = ExprCoreType.TIMESTAMP;
} else {
widestType = WideningTypeRule.max(widestType, type);
}
final ExprType tempType = widestType;
widestType = resolveCommonType(widestType, type).orElseGet(() -> max(tempType, type));
} catch (ExpressionEvaluationException e) {
// the two types are not compatible, return null
return null;
Expand All @@ -136,4 +154,122 @@ private static boolean areDateAndTime(ExprType type1, ExprType type2) {
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
}

@VisibleForTesting
public static Optional<ExprType> resolveCommonType(ExprType left, ExprType right) {
return COMMON_COERCION_RULES.stream()
.map(rule -> rule.apply(left, right))
.flatMap(Optional::stream)
.findFirst();
}

public static boolean hasString(List<RexNode> rexNodeList) {
return rexNodeList.stream()
.map(RexNode::getType)
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
.anyMatch(t -> t == ExprCoreType.STRING);
}

private static final Set<ExprType> NUMBER_TYPES = ExprCoreType.numberTypes();

private static final List<CoercionRule> COMMON_COERCION_RULES =
List.of(
CoercionRule.of(
(left, right) -> areDateAndTime(left, right),
(left, right) -> ExprCoreType.TIMESTAMP),
CoercionRule.of(
(left, right) -> hasString(left, right) && hasNumber(left, right),
(left, right) -> ExprCoreType.DOUBLE),
CoercionRule.of(
(left, right) -> hasString(left, right) && hasBoolean(left, right),
(left, right) -> ExprCoreType.BOOLEAN));

private static boolean hasString(ExprType left, ExprType right) {
return left == ExprCoreType.STRING || right == ExprCoreType.STRING;
}

private static boolean hasNumber(ExprType left, ExprType right) {
return NUMBER_TYPES.contains(left) || NUMBER_TYPES.contains(right);
}

private static boolean hasBoolean(ExprType left, ExprType right) {
return left == ExprCoreType.BOOLEAN || right == ExprCoreType.BOOLEAN;
}

private record CoercionRule(
BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {

Optional<ExprType> apply(ExprType left, ExprType right) {
return predicate.test(left, right)
? Optional.of(resolver.apply(left, right))
: Optional.empty();
}

static CoercionRule of(
BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
return new CoercionRule(predicate, resolver);
}
}

private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE;
private static final int TYPE_EQUAL = 0;

private static int distance(ExprType type1, ExprType type2) {
return distance(type1, type2, TYPE_EQUAL);
}

private static int distance(ExprType type1, ExprType type2, int distance) {
if (type1 == type2) {
return distance;
} else if (type1 == UNKNOWN) {
return IMPOSSIBLE_WIDENING;
} else if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) {
return 1;
} else {
return type1.getParent().stream()
.map(parentOfType1 -> distance(parentOfType1, type2, distance + 1))
.reduce(Math::min)
.get();
}
}

/**
* The max type among two types. The max is defined as follow if type1 could widen to type2, then
* max is type2, vice versa if type1 couldn't widen to type2 and type2 could't widen to type1,
* then throw {@link ExpressionEvaluationException}.
*
* @param type1 type1
* @param type2 type2
* @return the max type among two types.
*/
public static ExprType max(ExprType type1, ExprType type2) {
int type1To2 = distance(type1, type2);
int type2To1 = distance(type2, type1);

if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) {
throw new ExpressionEvaluationException(
String.format("no max type of %s and %s ", type1, type2));
} else {
return type1To2 == Integer.MAX_VALUE ? type1 : type2;
}
}

public static int distance(List<ExprType> sourceTypes, List<ExprType> targetTypes) {
if (sourceTypes.size() != targetTypes.size()) {
return IMPOSSIBLE_WIDENING;
}

int totalDistance = 0;
for (int i = 0; i < sourceTypes.size(); i++) {
ExprType source = sourceTypes.get(i);
ExprType target = targetTypes.get(i);
int distance = distance(source, target);
if (distance == IMPOSSIBLE_WIDENING) {
return IMPOSSIBLE_WIDENING;
} else {
totalDistance += distance;
}
}
return totalDistance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,13 @@ public void registerExternalAggOperator(
aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler));
}

public void validateAggFunctionSignature(
BuiltinFunctionName functionName, RexNode field, List<RexNode> argList) {
public List<RexNode> validateAggFunctionSignature(
BuiltinFunctionName functionName,
RexNode field,
List<RexNode> argList,
RexBuilder rexBuilder) {
var implementation = getImplementation(functionName);
validateFunctionArgs(implementation, functionName, field, argList);
return validateFunctionArgs(implementation, functionName, field, argList, rexBuilder);
}

public RelBuilder.AggCall resolveAgg(
Expand All @@ -432,17 +435,21 @@ public RelBuilder.AggCall resolveAgg(
var implementation = getImplementation(functionName);

// Validation is done based on original argument types to generate error from user perspective.
validateFunctionArgs(implementation, functionName, field, argList);
List<RexNode> nodes =
validateFunctionArgs(implementation, functionName, field, argList, context.rexBuilder);

var handler = implementation.getValue();
return handler.apply(distinct, field, argList, context);
return nodes != null
? handler.apply(distinct, nodes.getFirst(), nodes.subList(1, nodes.size()), context)
: handler.apply(distinct, field, argList, context);
}

static void validateFunctionArgs(
static List<RexNode> validateFunctionArgs(
Pair<CalciteFuncSignature, AggHandler> implementation,
BuiltinFunctionName functionName,
RexNode field,
List<RexNode> argList) {
List<RexNode> argList,
RexBuilder rexBuilder) {
CalciteFuncSignature signature = implementation.getKey();

List<RelDataType> argTypes = new ArrayList<>();
Expand All @@ -455,19 +462,29 @@ static void validateFunctionArgs(
List<RelDataType> additionalArgTypes =
argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).toList();
argTypes.addAll(additionalArgTypes);
List<RexNode> coercionNodes = null;
if (!signature.match(functionName.getName(), argTypes)) {
String errorMessagePattern =
argTypes.size() <= 1
? "Aggregation function %s expects field type {%s}, but got %s"
: "Aggregation function %s expects field type and additional arguments {%s}, but got"
+ " %s";
throw new ExpressionEvaluationException(
String.format(
errorMessagePattern,
functionName,
signature.typeChecker().getAllowedSignatures(),
PlanUtils.getActualSignature(argTypes)));
List<RexNode> fields = new ArrayList<>();
fields.add(field);
fields.addAll(argList);
if (CoercionUtils.hasString(fields)) {
coercionNodes = CoercionUtils.castArguments(rexBuilder, signature.typeChecker(), fields);
}
if (coercionNodes == null) {
String errorMessagePattern =
argTypes.size() <= 1
? "Aggregation function %s expects field type {%s}, but got %s"
: "Aggregation function %s expects field type and additional arguments {%s}, but"
+ " got %s";
throw new ExpressionEvaluationException(
String.format(
errorMessagePattern,
functionName,
signature.typeChecker().getAllowedSignatures(),
PlanUtils.getActualSignature(argTypes)));
}
}
return coercionNodes;
}

private Pair<CalciteFuncSignature, AggHandler> getImplementation(
Expand Down Expand Up @@ -680,8 +697,14 @@ void populate() {

// Register ADDFUNCTION for numeric addition only
registerOperator(ADDFUNCTION, SqlStdOperatorTable.PLUS);
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
registerOperator(SUBTRACTFUNCTION, SqlStdOperatorTable.MINUS);
registerOperator(
SUBTRACT,
SqlStdOperatorTable.MINUS,
PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
registerOperator(
SUBTRACTFUNCTION,
SqlStdOperatorTable.MINUS,
PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);
registerOperator(MULTIPLYFUNCTION, SqlStdOperatorTable.MULTIPLY);
registerOperator(TRUNCATE, SqlStdOperatorTable.TRUNCATE);
Expand Down Expand Up @@ -1341,7 +1364,7 @@ private static PPLTypeChecker wrapSqlOperandTypeChecker(
try {
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, !isUserDefinedFunction);
} catch (IllegalArgumentException | UnsupportedOperationException e) {
logger.debug(
logger.warn(
String.format(
"Failed to create composite type checker for operator: %s. Will skip its type"
+ " checking",
Expand Down
Loading
Loading