diff --git a/docs/changelog/111552.yaml b/docs/changelog/111552.yaml new file mode 100644 index 0000000000000..d9991788d4fa9 --- /dev/null +++ b/docs/changelog/111552.yaml @@ -0,0 +1,5 @@ +pr: 111552 +summary: Siem ea 9521 improve test +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java index c96599eaf8236..5f6bc8361c1bb 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java @@ -4,6 +4,7 @@ // 2.0. package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; +import java.lang.ArithmeticException; import java.lang.Override; import java.lang.String; import java.util.function.Function; @@ -59,7 +60,12 @@ public DoubleBlock eval(int positionCount, DoubleBlock blockBlock) { result.appendNull(); continue position; } - MvPSeriesWeightedSum.process(result, p, blockBlock, this.sum, this.p); + try { + MvPSeriesWeightedSum.process(result, p, blockBlock, this.sum, this.p); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java index 60eab9fd4ad74..212f626090789 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java @@ -37,8 +37,9 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.elasticsearch.xpack.esql.core.type.DataType.NULL; /** * Reduce a multivalued field to a single valued field containing the weighted sum of all element applying the P series function. @@ -89,14 +90,18 @@ protected TypeResolution resolveType() { return resolution; } - resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double") - .and(isNotNullAndFoldable(p, sourceText(), SECOND)); - + resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double"); if (resolution.unresolved()) { return resolution; } - return resolution; + if (p.dataType() == NULL) { + // If the type is `null` this parameter doesn't have to be foldable. It's effectively foldable anyway. + // TODO figure out if the tests are wrong here, or if null is really different from foldable null + return resolution; + } + + return isFoldable(p, sourceText(), SECOND); } @Override @@ -130,10 +135,13 @@ protected NodeInfo info() { @Override public DataType dataType() { + if (p.dataType() == NULL) { + return NULL; + } return field.dataType(); } - @Evaluator(extraName = "Double") + @Evaluator(extraName = "Double", warnExceptions = ArithmeticException.class) static void process( DoubleBlock.Builder builder, int position, @@ -149,7 +157,11 @@ static void process( double current_score = block.getDouble(i) / Math.pow(i - start + 1, p); sum.add(current_score); } - builder.appendDouble(sum.value()); + if (Double.isFinite(sum.value())) { + builder.appendDouble(sum.value()); + } else { + throw new ArithmeticException("double overflow"); + } } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java index 3e73a5606652d..66b587f257e2e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java @@ -304,7 +304,7 @@ public final void testFold() { assertTypeResolutionFailure(expression); return; } - assertFalse(expression.typeResolved().unresolved()); + assertFalse("expected resolved", expression.typeResolved().unresolved()); Expression nullOptimized = new FoldNull().rule(expression); assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType())); assertTrue(nullOptimized.foldable()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java index 0f277485b874d..156fc4bfe7c36 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java @@ -15,13 +15,14 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.hamcrest.Matcher; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; public class MvPSeriesWeightedSumTests extends AbstractScalarFunctionTestCase { public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier testCaseSupplier) { @@ -31,10 +32,21 @@ public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier parameters() { List cases = new ArrayList<>(); - doubles(cases); - // TODO use parameterSuppliersFromTypedDataWithDefaultChecks instead of parameterSuppliersFromTypedData and fix errors + cases = randomizeBytesRefsOffset(cases); + cases = anyNullIsNull( + cases, + (nullPosition, nullValueDataType, original) -> nullValueDataType == DataType.NULL ? DataType.NULL : original.expectedType(), + (nullPosition, nullData, original) -> { + if (nullData.isForceLiteral()) { + return equalTo("LiteralsEvaluator[lit=null]"); + } + return nullData.type() == DataType.NULL ? equalTo("LiteralsEvaluator[lit=null]") : original; + } + ); + cases = errorsForCasesWithoutExamples(cases, (valid, position) -> "double"); + return parameterSuppliersFromTypedData(cases); } @@ -47,22 +59,51 @@ private static void doubles(List cases) { cases.add(new TestCaseSupplier("most common scenario", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { List field = randomList(1, 10, () -> randomDoubleBetween(1, 10, false)); double p = randomDoubleBetween(-10, 10, true); - double expectedResult = calcPSeriesWeightedSum(field, p); - - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() - ), - "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", - DataType.DOUBLE, - match(expectedResult) - ); + return testCase(field, p); + })); + + cases.add(new TestCaseSupplier("values between 0 and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(0, 1, true)); + double p = randomDoubleBetween(-10, 10, true); + return testCase(field, p); + })); + + cases.add(new TestCaseSupplier("values between -1 and 0", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(-1, 0, true)); + double p = randomDoubleBetween(-10, 10, true); + return testCase(field, p); + })); + + cases.add(new TestCaseSupplier("values between 1 and Double.MAX_VALUE", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(1, Double.MAX_VALUE, true)); + double p = randomDoubleBetween(-10, 10, true); + return testCase(field, p); + })); + + cases.add(new TestCaseSupplier("values between -Double.MAX_VALUE and 1", List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDoubleBetween(-Double.MAX_VALUE, 1, true)); + double p = randomDoubleBetween(-10, 10, true); + return testCase(field, p); })); } - private static Matcher match(Double value) { - return closeTo(value, Math.abs(value * .00000001)); + private static TestCaseSupplier.TestCase testCase(List field, double p) { + double expectedResult = calcPSeriesWeightedSum(field, p); + + TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + Double.isFinite(expectedResult) ? closeTo(expectedResult, Math.abs(expectedResult * .00000001)) : nullValue() + ); + if (Double.isFinite(expectedResult) == false) { + return testCase.withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") + .withWarning("Line -1:-1: java.lang.ArithmeticException: double overflow"); + } + return testCase; } private static double calcPSeriesWeightedSum(List field, double p) {