diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index bb6f16b2e8a..ef6def9d4dd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -445,9 +445,26 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte (arguments.isEmpty() || arguments.size() == 1) ? Collections.emptyList() : arguments.subList(1, arguments.size()); - PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args); - return PlanUtils.makeOver( - context, functionName, field, args, partitions, List.of(), node.getWindowFrame()); + List nodes = + PPLFuncImpTable.INSTANCE.validateAggFunctionSignature( + functionName, field, args, context.rexBuilder); + return nodes != null + ? PlanUtils.makeOver( + context, + functionName, + nodes.getFirst(), + nodes.size() <= 1 ? 
Collections.emptyList() : nodes.subList(1, nodes.size()), + partitions, + List.of(), + node.getWindowFrame()) + : PlanUtils.makeOver( + context, + functionName, + field, + args, + partitions, + List.of(), + node.getWindowFrame()); }) .orElseThrow( () -> diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java index f0c6fc84837..ce78d6dec21 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java @@ -5,19 +5,27 @@ package org.opensearch.sql.expression.function; +import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN; + +import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.List; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.BinaryOperator; +import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.data.type.WideningTypeRule; import org.opensearch.sql.exception.ExpressionEvaluationException; -public class CoercionUtils { - +public final class CoercionUtils { /** * Casts the arguments to the types specified in the typeChecker. Returns null if no combination * of parameter types matches the arguments or if casting fails. @@ -31,15 +39,26 @@ public class CoercionUtils { RexBuilder builder, PPLTypeChecker typeChecker, List arguments) { List> paramTypeCombinations = typeChecker.getParameterTypes(); - // TODO: var args? 
- + List sourceTypes = + arguments.stream() + .map(node -> OpenSearchTypeFactory.convertRelDataTypeToExprType(node.getType())) + .collect(Collectors.toList()); + // Candidate parameter signatures ordered by decreasing widening distance + PriorityQueue, Integer>> rankedSignatures = + new PriorityQueue<>((left, right) -> Integer.compare(right.getValue(), left.getValue())); for (List paramTypes : paramTypeCombinations) { - List castedArguments = castArguments(builder, paramTypes, arguments); - if (castedArguments != null) { - return castedArguments; + int distance = distance(sourceTypes, paramTypes); + if (distance == TYPE_EQUAL) { + return castArguments(builder, paramTypes, arguments); } + Optional.of(distance) + .filter(value -> value != IMPOSSIBLE_WIDENING) + .ifPresent(value -> rankedSignatures.add(Pair.of(paramTypes, value))); } - return null; + return Optional.ofNullable(rankedSignatures.peek()) + .map(Pair::getKey) + .map(paramTypes -> castArguments(builder, paramTypes, arguments)) + .orElse(null); } /** @@ -90,11 +109,16 @@ public class CoercionUtils { if (!argType.shouldCast(targetType)) { return arg; } - - if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) { - return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg); + if (distance(argType, targetType) != IMPOSSIBLE_WIDENING) { + return builder.makeCast( + OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true); } - return null; + return resolveCommonType(argType, targetType) + .map( + exprType -> + builder.makeCast( + OpenSearchTypeFactory.convertExprTypeToRelDataType(exprType), arg, true, true)) + .orElse(null); } /** @@ -118,12 +142,8 @@ public class CoercionUtils { for (int i = 1; i < arguments.size(); i++) { var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType()); try { - if (areDateAndTime(widestType, type)) { - // If one is date and the other is time, we consider 
timestamp as the widest type - widestType = ExprCoreType.TIMESTAMP; - } else { - widestType = WideningTypeRule.max(widestType, type); - } + final ExprType tempType = widestType; + widestType = resolveCommonType(widestType, type).orElseGet(() -> max(tempType, type)); } catch (ExpressionEvaluationException e) { // the two types are not compatible, return null return null; @@ -136,4 +156,119 @@ private static boolean areDateAndTime(ExprType type1, ExprType type2) { return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME) || (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE); } + + @VisibleForTesting + public static Optional resolveCommonType(ExprType left, ExprType right) { + return COMMON_COERCION_RULES.stream() + .map(rule -> rule.apply(left, right)) + .flatMap(Optional::stream) + .findFirst(); + } + + public static boolean hasString(List rexNodeList) { + return rexNodeList.stream() + .map(RexNode::getType) + .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) + .anyMatch(t -> t == ExprCoreType.STRING); + } + + private static final Set NUMBER_TYPES = ExprCoreType.numberTypes(); + + private static final List COMMON_COERCION_RULES = + List.of( + CoercionRule.of( + (left, right) -> areDateAndTime(left, right), + (left, right) -> ExprCoreType.TIMESTAMP), + CoercionRule.of( + (left, right) -> hasString(left, right) && hasNumber(left, right), + (left, right) -> ExprCoreType.DOUBLE)); + + private static boolean hasString(ExprType left, ExprType right) { + return left == ExprCoreType.STRING || right == ExprCoreType.STRING; + } + + private static boolean hasNumber(ExprType left, ExprType right) { + return NUMBER_TYPES.contains(left) || NUMBER_TYPES.contains(right); + } + + private static boolean hasBoolean(ExprType left, ExprType right) { + return left == ExprCoreType.BOOLEAN || right == ExprCoreType.BOOLEAN; + } + + private record CoercionRule( + BiPredicate predicate, BinaryOperator resolver) { + + Optional apply(ExprType left, ExprType right) { + 
return predicate.test(left, right) + ? Optional.of(resolver.apply(left, right)) + : Optional.empty(); + } + + static CoercionRule of( + BiPredicate predicate, BinaryOperator resolver) { + return new CoercionRule(predicate, resolver); + } + } + + private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE; + private static final int TYPE_EQUAL = 0; + + private static int distance(ExprType type1, ExprType type2) { + return distance(type1, type2, TYPE_EQUAL); + } + + private static int distance(ExprType type1, ExprType type2, int distance) { + if (type1 == type2) { + return distance; + } else if (type1 == UNKNOWN) { + return IMPOSSIBLE_WIDENING; + } else if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) { + return 1; + } else { + return type1.getParent().stream() + .map(parentOfType1 -> distance(parentOfType1, type2, distance + 1)) + .reduce(Math::min) + .get(); + } + } + + /** + * The max type among two types, defined as follows: if type1 could widen to type2, then max is + * type2, and vice versa; if type1 couldn't widen to type2 and type2 couldn't widen to type1, + * then throw {@link ExpressionEvaluationException}. + * + * @param type1 type1 + * @param type2 type2 + * @return the max type among two types. + */ + public static ExprType max(ExprType type1, ExprType type2) { + int type1To2 = distance(type1, type2); + int type2To1 = distance(type2, type1); + + if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) { + throw new ExpressionEvaluationException( + String.format("no max type of %s and %s ", type1, type2)); + } else { + return type1To2 == Integer.MAX_VALUE ? 
type1 : type2; + } + } + + public static int distance(List sourceTypes, List targetTypes) { + if (sourceTypes.size() != targetTypes.size()) { + return IMPOSSIBLE_WIDENING; + } + + int totalDistance = 0; + for (int i = 0; i < sourceTypes.size(); i++) { + ExprType source = sourceTypes.get(i); + ExprType target = targetTypes.get(i); + int distance = distance(source, target); + if (distance == IMPOSSIBLE_WIDENING) { + return IMPOSSIBLE_WIDENING; + } else { + totalDistance += distance; + } + } + return totalDistance; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index afe5df01cf1..df6def6d097 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -259,6 +259,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlTrimFunction.Flag; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; +import org.apache.calcite.sql.type.FamilyOperandTypeChecker; import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; @@ -417,10 +418,13 @@ public void registerExternalAggOperator( aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler)); } - public void validateAggFunctionSignature( - BuiltinFunctionName functionName, RexNode field, List argList) { + public List validateAggFunctionSignature( + BuiltinFunctionName functionName, + RexNode field, + List argList, + RexBuilder rexBuilder) { var implementation = getImplementation(functionName); - validateFunctionArgs(implementation, functionName, field, argList); + return validateFunctionArgs(implementation, functionName, field, argList, rexBuilder); } public RelBuilder.AggCall resolveAgg( @@ -432,17 
+436,21 @@ public RelBuilder.AggCall resolveAgg( var implementation = getImplementation(functionName); // Validation is done based on original argument types to generate error from user perspective. - validateFunctionArgs(implementation, functionName, field, argList); + List nodes = + validateFunctionArgs(implementation, functionName, field, argList, context.rexBuilder); var handler = implementation.getValue(); - return handler.apply(distinct, field, argList, context); + return nodes != null + ? handler.apply(distinct, nodes.getFirst(), nodes.subList(1, nodes.size()), context) + : handler.apply(distinct, field, argList, context); } - static void validateFunctionArgs( + static List validateFunctionArgs( Pair implementation, BuiltinFunctionName functionName, RexNode field, - List argList) { + List argList, + RexBuilder rexBuilder) { CalciteFuncSignature signature = implementation.getKey(); List argTypes = new ArrayList<>(); @@ -455,19 +463,29 @@ static void validateFunctionArgs( List additionalArgTypes = argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).toList(); argTypes.addAll(additionalArgTypes); + List coercionNodes = null; if (!signature.match(functionName.getName(), argTypes)) { - String errorMessagePattern = - argTypes.size() <= 1 - ? "Aggregation function %s expects field type {%s}, but got %s" - : "Aggregation function %s expects field type and additional arguments {%s}, but got" - + " %s"; - throw new ExpressionEvaluationException( - String.format( - errorMessagePattern, - functionName, - signature.typeChecker().getAllowedSignatures(), - PlanUtils.getActualSignature(argTypes))); + List fields = new ArrayList<>(); + fields.add(field); + fields.addAll(argList); + if (CoercionUtils.hasString(fields)) { + coercionNodes = CoercionUtils.castArguments(rexBuilder, signature.typeChecker(), fields); + } + if (coercionNodes == null) { + String errorMessagePattern = + argTypes.size() <= 1 + ? 
"Aggregation function %s expects field type {%s}, but got %s" + : "Aggregation function %s expects field type and additional arguments {%s}, but" + + " got %s"; + throw new ExpressionEvaluationException( + String.format( + errorMessagePattern, + functionName, + signature.typeChecker().getAllowedSignatures(), + PlanUtils.getActualSignature(argTypes))); + } } + return coercionNodes; } private Pair getImplementation( @@ -680,8 +698,14 @@ void populate() { // Register ADDFUNCTION for numeric addition only registerOperator(ADDFUNCTION, SqlStdOperatorTable.PLUS); - registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS); - registerOperator(SUBTRACTFUNCTION, SqlStdOperatorTable.MINUS); + registerOperator( + SUBTRACTFUNCTION, + SqlStdOperatorTable.MINUS, + PPLTypeChecker.wrapFamily((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC)); + registerOperator( + SUBTRACT, + SqlStdOperatorTable.MINUS, + PPLTypeChecker.wrapFamily((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC)); registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY); registerOperator(MULTIPLYFUNCTION, SqlStdOperatorTable.MULTIPLY); registerOperator(TRUNCATE, SqlStdOperatorTable.TRUNCATE); @@ -739,13 +763,37 @@ void populate() { registerOperator(ASIN, SqlStdOperatorTable.ASIN); registerOperator(ATAN, SqlStdOperatorTable.ATAN); registerOperator(ATAN2, SqlStdOperatorTable.ATAN2); - registerOperator(CEIL, SqlStdOperatorTable.CEIL); - registerOperator(CEILING, SqlStdOperatorTable.CEIL); + // TODO, workaround to support sequence CompositeOperandTypeChecker. + registerOperator( + CEIL, + SqlStdOperatorTable.CEIL, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC_OR_INTERVAL.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.ANY)), + false)); + // TODO, workaround to support sequence CompositeOperandTypeChecker. 
+ registerOperator( + CEILING, + SqlStdOperatorTable.CEIL, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC_OR_INTERVAL.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.ANY)), + false)); registerOperator(COS, SqlStdOperatorTable.COS); registerOperator(COT, SqlStdOperatorTable.COT); registerOperator(DEGREES, SqlStdOperatorTable.DEGREES); registerOperator(EXP, SqlStdOperatorTable.EXP); - registerOperator(FLOOR, SqlStdOperatorTable.FLOOR); + // TODO, workaround to support sequence CompositeOperandTypeChecker. + registerOperator( + FLOOR, + SqlStdOperatorTable.FLOOR, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC_OR_INTERVAL.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.ANY)), + false)); registerOperator(LN, SqlStdOperatorTable.LN); registerOperator(LOG10, SqlStdOperatorTable.LOG10); registerOperator(PI, SqlStdOperatorTable.PI); @@ -753,7 +801,15 @@ void populate() { registerOperator(POWER, SqlStdOperatorTable.POWER); registerOperator(RADIANS, SqlStdOperatorTable.RADIANS); registerOperator(RAND, SqlStdOperatorTable.RAND); - registerOperator(ROUND, SqlStdOperatorTable.ROUND); + // TODO, workaround to support sequence CompositeOperandTypeChecker. 
+ registerOperator( + ROUND, + SqlStdOperatorTable.ROUND, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC.or( + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER)), + false)); registerOperator(SIGN, SqlStdOperatorTable.SIGN); registerOperator(SIGNUM, SqlStdOperatorTable.SIGN); registerOperator(SIN, SqlStdOperatorTable.SIN); diff --git a/core/src/test/java/org/opensearch/sql/expression/function/CoercionUtilsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/CoercionUtilsTest.java new file mode 100644 index 00000000000..30d827f1ecc --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/CoercionUtilsTest.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.*; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import java.util.stream.Stream; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; + +class CoercionUtilsTest { + + private static final RexBuilder REX_BUILDER = new RexBuilder(OpenSearchTypeFactory.TYPE_FACTORY); + + private static RexNode nullLiteral(ExprCoreType type) { + return 
REX_BUILDER.makeNullLiteral(OpenSearchTypeFactory.convertExprTypeToRelDataType(type)); + } + + private static Stream commonWidestTypeArguments() { + return Stream.of( + Arguments.of(STRING, INTEGER, DOUBLE), + Arguments.of(INTEGER, STRING, DOUBLE), + Arguments.of(STRING, DOUBLE, DOUBLE), + Arguments.of(INTEGER, BOOLEAN, null)); + } + + @ParameterizedTest + @MethodSource("commonWidestTypeArguments") + public void findCommonWidestType( + ExprCoreType left, ExprCoreType right, ExprCoreType expectedCommonType) { + assertEquals( + expectedCommonType, CoercionUtils.resolveCommonType(left, right).orElseGet(() -> null)); + } + + @Test + void castArgumentsReturnsExactMatchWhenAvailable() { + PPLTypeChecker typeChecker = new StubTypeChecker(List.of(List.of(INTEGER), List.of(DOUBLE))); + List arguments = List.of(nullLiteral(INTEGER)); + + List result = CoercionUtils.castArguments(REX_BUILDER, typeChecker, arguments); + + assertNotNull(result); + assertEquals(1, result.size()); + assertEquals( + INTEGER, OpenSearchTypeFactory.convertRelDataTypeToExprType(result.getFirst().getType())); + } + + @Test + void castArgumentsFallsBackToWidestCandidate() { + PPLTypeChecker typeChecker = + new StubTypeChecker(List.of(List.of(ExprCoreType.LONG), List.of(DOUBLE))); + List arguments = List.of(nullLiteral(STRING)); + + List result = CoercionUtils.castArguments(REX_BUILDER, typeChecker, arguments); + + assertNotNull(result); + assertEquals( + DOUBLE, OpenSearchTypeFactory.convertRelDataTypeToExprType(result.getFirst().getType())); + } + + @Test + void castArgumentsReturnsNullWhenNoCompatibleSignatureExists() { + PPLTypeChecker typeChecker = new StubTypeChecker(List.of(List.of(ExprCoreType.GEO_POINT))); + List arguments = List.of(nullLiteral(INTEGER)); + + assertNull(CoercionUtils.castArguments(REX_BUILDER, typeChecker, arguments)); + } + + private static class StubTypeChecker implements PPLTypeChecker { + private final List> signatures; + + private StubTypeChecker(List> signatures) { + 
this.signatures = signatures; + } + + @Override + public boolean checkOperandTypes(List types) { + return false; + } + + @Override + public String getAllowedSignatures() { + return ""; + } + + @Override + public List> getParameterTypes() { + return signatures; + } + } +} diff --git a/docs/category.json b/docs/category.json index b46c36afdef..49529b08bdc 100644 --- a/docs/category.json +++ b/docs/category.json @@ -63,10 +63,11 @@ "user/ppl/functions/math.rst", "user/ppl/functions/relevance.rst", "user/ppl/functions/string.rst", + "user/ppl/functions/conversion.rst", "user/ppl/general/datatypes.rst", "user/ppl/general/identifiers.rst" ], "bash_settings": [ "user/ppl/admin/settings.rst" ] -} \ No newline at end of file +} diff --git a/docs/user/ppl/functions/conversion.rst b/docs/user/ppl/functions/conversion.rst index dbe4403540c..849d2334e41 100644 --- a/docs/user/ppl/functions/conversion.rst +++ b/docs/user/ppl/functions/conversion.rst @@ -46,7 +46,7 @@ Cast to string example:: +-------+------+------------+ | cbool | cint | cdate | |-------+------+------------| - | true | 1 | 2012-08-07 | + | TRUE | 1 | 2012-08-07 | +-------+------+------------+ Cast to number example:: @@ -78,3 +78,42 @@ Cast function can be chained:: |-------| | True | +-------+ + + +IMPLICIT (AUTO) TYPE CONVERSION +------------------------------- + +Implicit conversion is automatic casting. When a function does not have an exact match for the +input types, the engine looks for another signature that can safely work with the values. It picks +the option that requires the least stretching of the original types, so you can mix literals and +fields without adding ``CAST`` everywhere. + +String to numeric +>>>>>>>>>>>>>>>>> + +When a string stands in for a number we simply parse the text: + +- The value must be something like ``"3.14"`` or ``"42"``. Anything else is evaluated as ``null``. 
+- If a string appears next to numeric arguments, it is treated as a ``DOUBLE`` so the numeric + overload of the function can run. + +Use string in arithmetic operator example :: + + os> source=people | eval divide="5"/10, multiply="5" * 10, add="5" + 10, minus="5" - 10, concat="5" + "5" | fields divide, multiply, add, minus, concat + fetched rows / total rows = 1/1 + +--------+----------+------+-------+--------+ + | divide | multiply | add | minus | concat | + |--------+----------+------+-------+--------| + | 0.5 | 50.0 | 15.0 | -5.0 | 55 | + +--------+----------+------+-------+--------+ + +Use string in comparison operator example :: + + os> source=people | eval e="1000"==1000, en="1000"!=1000, ed="1000"==1000.0, edn="1000"!=1000.0, l="1000">999, ld="1000">999.9, i="malformed"==1000 | fields e, en, ed, edn, l, ld, i + fetched rows / total rows = 1/1 + +------+-------+------+-------+------+------+------+ + | e | en | ed | edn | l | ld | i | + |------+-------+------+-------+------+------+------| + | True | False | True | False | True | True | null | + +------+-------+------+-------+------+------+------+ + diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteDateTimeFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteDateTimeFunctionIT.java index 8eee5c01f7c..ef0c0599b57 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteDateTimeFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteDateTimeFunctionIT.java @@ -132,22 +132,13 @@ public void testStrftimeWithExpressions() throws IOException { @Test public void testStrftimeStringHandling() throws IOException { - try { - executeQuery( - String.format( - "source=%s | eval result = strftime('1521467703', '%s') | fields result | head 1", - TEST_INDEX_DATE, "%Y-%m-%d")); - fail("String literals should not be accepted by strftime"); - } catch (Exception e) { - // Expected - string literals are not supported - 
// The error occurs because Calcite tries to convert the string to a timestamp - // which doesn't match the expected timestamp format - assertTrue( - "Error should indicate format issue or type problem", - e.getMessage().contains("unsupported format") - || e.getMessage().contains("timestamp") - || e.getMessage().contains("500")); - } + // Test 1: Support string literal + JSONObject result0 = + executeQuery( + String.format( + "source=%s | eval result = strftime('1521467703', '%s') | fields result | head 1", + TEST_INDEX_DATE, "%Y-%m-%d")); + verifyDataRows(result0, rows("2018-03-19")); // Test 2: The correct approach - use numeric literals directly JSONObject result1 = diff --git a/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/4356.yml b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/4356.yml new file mode 100644 index 00000000000..01d8dd6de16 --- /dev/null +++ b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/4356.yml @@ -0,0 +1,192 @@ +setup: + - do: + query.settings: + body: + transient: + plugins.calcite.enabled : true + + - do: + indices.create: + index: log00001 + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + v: + type: text + strnum: + type: keyword + vint: + type: integer + vdouble: + type: double + vboolean: + type: boolean + + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "log00001", "_id": 1}}' + - '{"v": "value=1", "a": 1, "vint": 1, "vdouble": 1.0, "strnum": "1"}' + - '{"index": {"_index": "log00001", "_id": 2}}' + - '{"v": "value=1.5", "a": 2, "vint": 1, "vdouble": 1.5, "strnum": "2"}' + - '{"index": {"_index": "log00001", "_id": 3}}' + - '{"v": "value=true", "a": 3, "vint": 1, "vdouble": 1.0, "vboolean":true, "strnum": "3"}' + - '{"index": {"_index": "log00001", "_id": 4}}' + - '{"v": "value=abcde", "a": 4, "vint": 1, "vdouble": 1.0, "strnum": "malformed"}' + - do: + indices.create: + index: log00002 + body: + settings: + 
number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + id: + type: integer + + + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "log00002", "_id": 1}}' + - '{"id": 1}' + +--- +teardown: + - do: + query.settings: + body: + transient: + plugins.calcite.enabled : false + - do: + indices.delete: + index: log00001 + ignore_unavailable: true + +--- +"Extracted value participate in arithmetic operator": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eval m=digits * 10 | eval d=digits/10 | sort a | fields m, d + - match: {"schema": [{"name": "m", "type": "double"}, {"name": "d", "type": "double"}]} + - match: {"datarows": [[10.0, 0.1], [15.0, 0.15], [null, null], [null, null]]} + + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eval m=digits + digits, d=digits * digits | sort a | fields m, d + - match: { "schema": [ { "name": "m", "type": "string" }, { "name": "d", "type": "double" } ] } + - match: { "datarows": [ [ "11", 1.0 ], [ "1.51.5", 2.25 ], [ "truetrue", null ], [ "abcdeabcde", null ] ] } + + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00002 | eval m="5" - 10 | eval r=round("1.5", 1) | eval f=floor("5.2") | eval c=ceil("5.2") | fields m, r, f, c + - match: { "schema": [ { "name": "m", "type": "double" }, { "name": "r", "type": "double" }, { "name": "f", "type": "double" }, { "name": "c", "type": "double" }] } + - match: { "datarows": [ [ -5.0, 1.5, 5.0, 6.0] ] } + +--- +"Extracted value participate in comparison operator": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eval i=digits==vint, d=digits==vdouble, b=digits==vboolean| fields i, d, b + - match: {"schema": 
[{"name": "i", "type": "boolean"}, {"name": "d", "type": "boolean"}, {"name": "b", "type": "boolean"}]} + - match: {"datarows": [[true,true,null], [false,true,null], [null, null, true], [null, null, null]]} + + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00002 | eval e='1000'==1000, en='1000'!=1000, ed='1000'==1000.0, edn='1000'!=1000.0, l='1000'>999, ld='1000'>999.9, i="malformed"==1000 | fields e, en, ed, edn, l, ld, i + - match: {"schema": [{"name": "e", "type": "boolean"}, {"name": "en", "type": "boolean"}, {"name": "ed", "type": "boolean"}, {"name": "edn", "type": "boolean"}, {"name": "l", "type": "boolean"}, {"name": "ld", "type": "boolean"}, {"name": "i", "type": "boolean"}]} + - match: {"datarows": [[true, false, true, false, true, true, null]]} + +--- +"Extracted value participate in string func": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eval r=concat('v-', digits) | sort a | fields r + - match: {"schema": [{"name": "r", "type": "string"}]} + - match: {"datarows": [["v-1"], ["v-1.5"], ["v-true"], ["v-abcde"]]} + + +--- +"Extracted value participate in condition func": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eval isNull=isnull(digits) | fields isNull + - match: {"schema": [{"name": "isNull", "type": "boolean"}]} + - match: {"datarows": [[false], [false], [false], [false]]} + +--- +"Extracted value participate in aggregation func": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | stats count(digits) as cnt, sum(digits) as sum, avg(digits) as avg + - match: {"schema": [{"name": "cnt", "type": "bigint"}, {"name": "sum", "type": "double"}, {"name": 
"avg", "type": "double"}]} + - match: {"datarows": [[4, 2.5, 1.25]]} + + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | rex field=v 'value=(?[\\w\\d\\.]*)' | eventstats sum(digits) as sum, count(digits) as cnt| fields a, sum, cnt + - match: { "schema": [ { "name": "a", "type": "bigint" }, { "name": "sum", "type": "double" }, { "name": "cnt", "type": "bigint" } ] } + - match: { "datarows": [ [1, 2.5, 4], [2, 2.5, 4], [3, 2.5, 4], [4, 2.5, 4] ] } + +--- +"Safe cast keyword to int": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: source=log00001 | stats count(strnum) as cnt, sum(strnum) as sum, avg(strnum) as avg + - match: {"schema": [{"name": "cnt", "type": "bigint"}, {"name": "sum", "type": "double"}, {"name": "avg", "type": "double"}]} +# Notice: Count is calculated on string value, sum and avg are calculated on numeric value, this is why sum/count!=avg + - match: {"datarows": [[4, 6.0, 2.0]]} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregateFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregateFunctionTypeTest.java index 1e1109c256f..b363be9bee1 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregateFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregateFunctionTypeTest.java @@ -17,48 +17,48 @@ public CalcitePPLAggregateFunctionTypeTest() { @Test public void testAvgWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | stats avg(ENAME) as avg_name", - "Aggregation function AVG expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | stats avg(HIREDATE) as avg_name", + "Aggregation function AVG expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testVarsampWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | stats var_samp(ENAME) as varsamp_name", - 
"Aggregation function VARSAMP expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | stats var_samp(HIREDATE) as varsamp_name", + "Aggregation function VARSAMP expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testVarpopWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | stats var_pop(ENAME) as varpop_name", - "Aggregation function VARPOP expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | stats var_pop(HIREDATE) as varpop_name", + "Aggregation function VARPOP expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testStddevSampWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | stats stddev_samp(ENAME) as stddev_name", + "source=EMP | stats stddev_samp(HIREDATE) as stddev_name", "Aggregation function STDDEV_SAMP expects field type {[INTEGER]|[DOUBLE]}, but got" - + " [STRING]"); + + " [DATE]"); } @Test public void testStddevPopWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | stats stddev_pop(ENAME) as stddev_name", + "source=EMP | stats stddev_pop(HIREDATE) as stddev_name", "Aggregation function STDDEV_POP expects field type {[INTEGER]|[DOUBLE]}, but got" - + " [STRING]"); + + " [DATE]"); } @Test public void testPercentileApproxWithWrongArgType() { // First argument should be numeric verifyQueryThrowsException( - "source=EMP | stats percentile_approx(ENAME, 50) as percentile", + "source=EMP | stats percentile_approx(HIREDATE, 50) as percentile", "Aggregation function PERCENTILE_APPROX expects field type and additional arguments" + " {[INTEGER,INTEGER]|[INTEGER,DOUBLE]|[DOUBLE,INTEGER]|[DOUBLE,DOUBLE]|[INTEGER,INTEGER,INTEGER]|[INTEGER,INTEGER,DOUBLE]|[INTEGER,DOUBLE,INTEGER]|[INTEGER,DOUBLE,DOUBLE]|[DOUBLE,INTEGER,INTEGER]|[DOUBLE,INTEGER,DOUBLE]|[DOUBLE,DOUBLE,INTEGER]|[DOUBLE,DOUBLE,DOUBLE]}," - + " but got [STRING,INTEGER]"); + + " but got [DATE,INTEGER]"); } @Test @@ -155,10 +155,10 @@ public void 
testPercentileWithMissingParametersThrowsException() { @Test public void testPercentileWithInvalidParameterTypesThrowsException() { verifyQueryThrowsException( - "source=EMP | stats percentile(EMPNO, 50, ENAME)", + "source=EMP | stats percentile(EMPNO, 50, HIREDATE)", "Aggregation function PERCENTILE_APPROX expects field type and additional arguments" + " {[INTEGER,INTEGER]|[INTEGER,DOUBLE]|[DOUBLE,INTEGER]|[DOUBLE,DOUBLE]|[INTEGER,INTEGER,INTEGER]|[INTEGER,INTEGER,DOUBLE]|[INTEGER,DOUBLE,INTEGER]|[INTEGER,DOUBLE,DOUBLE]|[DOUBLE,INTEGER,INTEGER]|[DOUBLE,INTEGER,DOUBLE]|[DOUBLE,DOUBLE,INTEGER]|[DOUBLE,DOUBLE,DOUBLE]}," - + " but got [SHORT,INTEGER,STRING]"); + + " but got [SHORT,INTEGER,DATE]"); } @Test diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTypeTest.java index a6535755435..24bd9ac18d0 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTypeTest.java @@ -107,37 +107,37 @@ public void testLatestWithTooManyParametersThrowsException() { @Test public void testAvgWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | eventstats avg(ENAME) as avg_name", - "Aggregation function AVG expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | eventstats avg(HIREDATE) as avg_name", + "Aggregation function AVG expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testVarsampWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | eventstats var_samp(ENAME) as varsamp_name", - "Aggregation function VARSAMP expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | eventstats var_samp(HIREDATE) as varsamp_name", + "Aggregation function VARSAMP expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testVarpopWithWrongArgType() { 
verifyQueryThrowsException( - "source=EMP | eventstats var_pop(ENAME) as varpop_name", - "Aggregation function VARPOP expects field type {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | eventstats var_pop(HIREDATE) as varpop_name", + "Aggregation function VARPOP expects field type {[INTEGER]|[DOUBLE]}, but got [DATE]"); } @Test public void testStddevSampWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | eventstats stddev_samp(ENAME) as stddev_name", + "source=EMP | eventstats stddev_samp(HIREDATE) as stddev_name", "Aggregation function STDDEV_SAMP expects field type {[INTEGER]|[DOUBLE]}, but got" - + " [STRING]"); + + " [DATE]"); } @Test public void testStddevPopWithWrongArgType() { verifyQueryThrowsException( - "source=EMP | eventstats stddev_pop(ENAME) as stddev_name", + "source=EMP | eventstats stddev_pop(HIREDATE) as stddev_name", "Aggregation function STDDEV_POP expects field type {[INTEGER]|[DOUBLE]}, but got" - + " [STRING]"); + + " [DATE]"); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index fc840169ee2..9513558952f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -45,12 +45,8 @@ public void testTimeDiffWithUdtInputType() { public void testComparisonWithDifferentType() { getRelNode("source=EMP | where EMPNO > 6 | fields ENAME"); getRelNode("source=EMP | where ENAME <= 'Jack' | fields ENAME"); - verifyQueryThrowsException( - "source=EMP | where ENAME < 6 | fields ENAME", - // Temporary fix for the error message as LESS function has two variants. Will remove - // [IP,IP] when merging the two variants. 
- "LESS function expects {[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}, but got" - + " [STRING,INTEGER]"); + // LogicalFilter(condition=[<(SAFE_CAST($1), 6.0E0)]) + getRelNode("source=EMP | where ENAME < 6 | fields ENAME"); } @Test @@ -151,8 +147,8 @@ public void testSha2WrongArgShouldThrow() { @Test public void testSqrtWithWrongArg() { verifyQueryThrowsException( - "source=EMP | head 1 | eval sqrt_name = sqrt(ENAME) | fields sqrt_name", - "SQRT function expects {[INTEGER]|[DOUBLE]}, but got [STRING]"); + "source=EMP | head 1 | eval sqrt_name = sqrt(HIREDATE) | fields sqrt_name", + "SQRT function expects {[INTEGER]|[DOUBLE]}, but got [DATE]"); } // Test UDF registered with PPL builtin operators: registerOperator(MOD, PPLBuiltinOperators.MOD);