Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,26 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
(arguments.isEmpty() || arguments.size() == 1)
? Collections.emptyList()
: arguments.subList(1, arguments.size());
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args);
return PlanUtils.makeOver(
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
List<RexNode> nodes =
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
functionName, field, args, context.rexBuilder);
return nodes != null
? PlanUtils.makeOver(
context,
functionName,
nodes.get(0),
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
partitions,
List.of(),
node.getWindowFrame())
: PlanUtils.makeOver(
context,
functionName,
field,
args,
partitions,
List.of(),
node.getWindowFrame());
})
.orElseThrow(
() ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.data.type.WideningTypeRule;
import org.opensearch.sql.exception.ExpressionEvaluationException;

public class CoercionUtils {

public final class CoercionUtils {
/**
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
* of parameter types matches the arguments or if casting fails.
Expand All @@ -32,15 +39,26 @@ public class CoercionUtils {
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();

// TODO: var args?

List<ExprType> sourceTypes =
arguments.stream()
.map(node -> OpenSearchTypeFactory.convertRelDataTypeToExprType(node.getType()))
.collect(Collectors.toList());
// Candidate parameter signatures ordered by decreasing widening distance
PriorityQueue<Pair<List<ExprType>, Integer>> rankedSignatures =
new PriorityQueue<>((left, right) -> Integer.compare(right.getValue(), left.getValue()));
for (List<ExprType> paramTypes : paramTypeCombinations) {
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
if (castedArguments != null) {
return castedArguments;
int distance = distance(sourceTypes, paramTypes);
if (distance == TYPE_EQUAL) {
return castArguments(builder, paramTypes, arguments);
}
Optional.of(distance)
.filter(value -> value != IMPOSSIBLE_WIDENING)
.ifPresent(value -> rankedSignatures.add(Pair.of(paramTypes, value)));
}
return null;
return Optional.ofNullable(rankedSignatures.peek())
.map(Pair::getKey)
.map(paramTypes -> castArguments(builder, paramTypes, arguments))
.orElse(null);
}

/**
Expand Down Expand Up @@ -91,11 +109,16 @@ public class CoercionUtils {
if (!argType.shouldCast(targetType)) {
return arg;
}

if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
if (distance(argType, targetType) != IMPOSSIBLE_WIDENING) {
return builder.makeCast(
OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true);
}
return null;
return resolveCommonType(argType, targetType)
.map(
exprType ->
builder.makeCast(
OpenSearchTypeFactory.convertExprTypeToRelDataType(exprType), arg, true, true))
.orElse(null);
}

/**
Expand All @@ -119,12 +142,8 @@ public class CoercionUtils {
for (int i = 1; i < arguments.size(); i++) {
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
try {
if (areDateAndTime(widestType, type)) {
// If one is date and the other is time, we consider timestamp as the widest type
widestType = ExprCoreType.TIMESTAMP;
} else {
widestType = WideningTypeRule.max(widestType, type);
}
final ExprType tempType = widestType;
widestType = resolveCommonType(widestType, type).orElseGet(() -> max(tempType, type));
} catch (ExpressionEvaluationException e) {
// the two types are not compatible, return null
return null;
Expand All @@ -137,4 +156,125 @@ private static boolean areDateAndTime(ExprType type1, ExprType type2) {
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
}

@VisibleForTesting
public static Optional<ExprType> resolveCommonType(ExprType left, ExprType right) {
return COMMON_COERCION_RULES.stream()
.map(rule -> rule.apply(left, right))
.flatMap(Optional::stream)
.findFirst();
}

public static boolean hasString(List<RexNode> rexNodeList) {
return rexNodeList.stream()
.map(RexNode::getType)
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
.anyMatch(t -> t == ExprCoreType.STRING);
}

private static final Set<ExprType> NUMBER_TYPES = ExprCoreType.numberTypes();

private static final List<CoercionRule> COMMON_COERCION_RULES =
List.of(
CoercionRule.of(
(left, right) -> areDateAndTime(left, right),
(left, right) -> ExprCoreType.TIMESTAMP),
CoercionRule.of(
(left, right) -> hasString(left, right) && hasNumber(left, right),
(left, right) -> ExprCoreType.DOUBLE));

private static boolean hasString(ExprType left, ExprType right) {
return left == ExprCoreType.STRING || right == ExprCoreType.STRING;
}

private static boolean hasNumber(ExprType left, ExprType right) {
return NUMBER_TYPES.contains(left) || NUMBER_TYPES.contains(right);
}

private static boolean hasBoolean(ExprType left, ExprType right) {
return left == ExprCoreType.BOOLEAN || right == ExprCoreType.BOOLEAN;
}

private static class CoercionRule {
private final BiPredicate<ExprType, ExprType> predicate;
private final BinaryOperator<ExprType> resolver;

public CoercionRule(BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
this.predicate = predicate;
this.resolver = resolver;
}

Optional<ExprType> apply(ExprType left, ExprType right) {
return predicate.test(left, right)
? Optional.of(resolver.apply(left, right))
: Optional.empty();
}

static CoercionRule of(
BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
return new CoercionRule(predicate, resolver);
}
}

private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE;
private static final int TYPE_EQUAL = 0;

private static int distance(ExprType type1, ExprType type2) {
return distance(type1, type2, TYPE_EQUAL);
}

private static int distance(ExprType type1, ExprType type2, int distance) {
if (type1 == type2) {
return distance;
} else if (type1 == UNKNOWN) {
return IMPOSSIBLE_WIDENING;
} else if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) {
return 1;
} else {
return type1.getParent().stream()
.map(parentOfType1 -> distance(parentOfType1, type2, distance + 1))
.reduce(Math::min)
.get();
}
}

/**
* The max type among two types. The max is defined as follow if type1 could widen to type2, then
* max is type2, vice versa if type1 couldn't widen to type2 and type2 could't widen to type1,
* then throw {@link ExpressionEvaluationException}.
*
* @param type1 type1
* @param type2 type2
* @return the max type among two types.
*/
public static ExprType max(ExprType type1, ExprType type2) {
int type1To2 = distance(type1, type2);
int type2To1 = distance(type2, type1);

if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) {
throw new ExpressionEvaluationException(
String.format("no max type of %s and %s ", type1, type2));
} else {
return type1To2 == Integer.MAX_VALUE ? type1 : type2;
}
}

public static int distance(List<ExprType> sourceTypes, List<ExprType> targetTypes) {
if (sourceTypes.size() != targetTypes.size()) {
return IMPOSSIBLE_WIDENING;
}

int totalDistance = 0;
for (int i = 0; i < sourceTypes.size(); i++) {
ExprType source = sourceTypes.get(i);
ExprType target = targetTypes.get(i);
int distance = distance(source, target);
if (distance == IMPOSSIBLE_WIDENING) {
return IMPOSSIBLE_WIDENING;
} else {
totalDistance += distance;
}
}
return totalDistance;
}
}
Loading
Loading