Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
import static org.opensearch.sql.utils.DateTimeUtils.findCastType;
import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated;

import java.math.BigDecimal;
import java.util.ArrayList;
Expand All @@ -23,7 +21,6 @@

import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
Expand All @@ -32,7 +29,6 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLambda;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
Expand Down Expand Up @@ -217,11 +213,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) {

@Override
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
RexNode leftCandidate = analyze(node.getLeft(), context);
RexNode rightCandidate = analyze(node.getRight(), context);
SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate);
final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget);
final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget);
RexNode left = analyze(node.getLeft(), context);
RexNode right = analyze(node.getRight(), context);
return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right);
}

Expand Down Expand Up @@ -470,19 +463,6 @@ private List<RelDataType> modifyLambdaTypeByFunction(
}
}

private List<RexNode> castArgument(
List<RexNode> originalArguments, String functionName, ExtendedRexBuilder rexBuilder) {
switch (functionName.toUpperCase(Locale.ROOT)) {
case "REDUCE":
RexLambda call = (RexLambda) originalArguments.get(2);
originalArguments.set(
1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true));
return originalArguments;
default:
return originalArguments;
}
}

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<UnresolvedExpression> args = node.getFuncArgs();
Expand All @@ -509,7 +489,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
}
}

arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder);
RexNode resolvedNode =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,40 +27,121 @@ private PPLOperandTypes() {}
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
public static final UDFOperandMetadata STRING =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING);
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER);
public static final UDFOperandMetadata INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER);
public static final UDFOperandMetadata NUMERIC =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC);

public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.NUMERIC.or(
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));

public static final UDFOperandMetadata INTEGER_INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
public static final UDFOperandMetadata STRING_STRING =
UDFOperandMetadata.wrap(OperandTypes.STRING_STRING);
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER);
public static final UDFOperandMetadata NUMERIC_NUMERIC =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC);
public static final UDFOperandMetadata STRING_INTEGER =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));

public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
UDFOperandMetadata.wrap(
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.or(
OperandTypes.family(
SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)));

public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family()));

public static final UDFOperandMetadata DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING));
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER));
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER));
public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER));

public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));

public static final UDFOperandMetadata DATETIME_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
public static final UDFOperandMetadata DATETIME_OR_STRING_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)
.or(OperandTypes.CHARACTER_CHARACTER));
public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.STRING_STRING
OperandTypes.CHARACTER_CHARACTER
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME))
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))
.or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
.or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata STRING_TIMESTAMP =
UDFOperandMetadata.wrap(
OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP));
public static final UDFOperandMetadata STRING_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME));
public static final UDFOperandMetadata DATETIME_INTERVAL =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL);
public static final UDFOperandMetadata TIME_TIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME));

public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
OperandTypes.family(
SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER)));
public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING));
public static final UDFOperandMetadata STRING_TIMESTAMP =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP));
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -380,4 +381,13 @@ public static Optional<BuiltinFunctionName> ofWindowFunction(String functionName
return Optional.ofNullable(
WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
}

public static final Set<BuiltinFunctionName> COMPARATORS =
Set.of(
BuiltinFunctionName.EQUAL,
BuiltinFunctionName.NOTEQUAL,
BuiltinFunctionName.LESS,
BuiltinFunctionName.LTE,
BuiltinFunctionName.GREATER,
BuiltinFunctionName.GTE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ public PPLTypeChecker getTypeChecker() {
return typeChecker;
}

public boolean match(FunctionName functionName, List<RelDataType> paramTypeList) {
public boolean match(FunctionName functionName, List<RelDataType> argTypes) {
if (!functionName.equals(this.functionName)) return false;
// For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED),
// the typeChecker will be null because only simple family-based type checks are currently
// supported.
if (typeChecker == null) return true;
return typeChecker.checkOperandTypes(paramTypeList);
return typeChecker.checkOperandTypes(argTypes);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.data.type.WideningTypeRule;
import org.opensearch.sql.exception.ExpressionEvaluationException;

public class CoercionUtils {

/**
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
* of parameter types matches the arguments or if casting fails.
*
* @param builder RexBuilder to create casts
* @param typeChecker PPLTypeChecker that provides the parameter types
* @param arguments List of RexNode arguments to be cast
* @return List of cast RexNode arguments or null if casting fails
*/
public static @Nullable List<RexNode> castArguments(
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();

// TODO: var args?

for (List<ExprType> paramTypes : paramTypeCombinations) {
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
if (castedArguments != null) {
return castedArguments;
}
}
return null;
}

/**
* Widen the arguments to the widest type found among them. If no widest type can be determined,
* returns null.
*
* @param builder RexBuilder to create casts
* @param arguments List of RexNode arguments to be widened
* @return List of widened RexNode arguments or null if no widest type can be determined
*/
public static @Nullable List<RexNode> widenArguments(
RexBuilder builder, List<RexNode> arguments) {
// TODO: Add test on e.g. IP
ExprType widestType = findWidestType(arguments);
if (widestType == null) {
return null; // No widest type found, return null
}
return arguments.stream().map(arg -> cast(builder, widestType, arg)).collect(Collectors.toList());
}

/**
* Casts the arguments to the types specified in paramTypes. Returns null if the number of
* parameters does not match or if casting fails.
*/
private static @Nullable List<RexNode> castArguments(
RexBuilder builder, List<ExprType> paramTypes, List<RexNode> arguments) {
if (paramTypes.size() != arguments.size()) {
return null; // Skip if the number of parameters does not match
}

List<RexNode> castedArguments = new ArrayList<>();
for (int i = 0; i < paramTypes.size(); i++) {
ExprType toType = paramTypes.get(i);
RexNode arg = arguments.get(i);

RexNode castedArg = cast(builder, toType, arg);

if (castedArg == null) {
return null;
}
castedArguments.add(castedArg);
}
return castedArguments;
}

private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
if (!argType.shouldCast(targetType)) {
return arg;
}

if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
}
return null;
}

/**
* Finds the widest type among the given arguments. The widest type is determined by applying the
* widening type rule to each pair of types in the arguments.
*
* @param arguments List of RexNode arguments to find the widest type from
* @return the widest ExprType if found, otherwise null
*/
private static @Nullable ExprType findWidestType(List<RexNode> arguments) {
if (arguments.isEmpty()) {
return null; // No arguments to process
}
ExprType widestType =
OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(0).getType());
if (arguments.size() == 1) {
return widestType;
}

// Iterate pairwise through the arguments and find the widest type
for (int i = 1; i < arguments.size(); i++) {
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
try {
if (areDateAndTime(widestType, type)) {
// If one is date and the other is time, we consider timestamp as the widest type
widestType = ExprCoreType.TIMESTAMP;
} else {
widestType = WideningTypeRule.max(widestType, type);
}
} catch (ExpressionEvaluationException e) {
// the two types are not compatible, return null
return null;
}
}
return widestType;
}

private static boolean areDateAndTime(ExprType type1, ExprType type2) {
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
}
}
Loading
Loading