From f36e6109ddc805a3ff16a06b0279c9af7ce3c6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 19 Aug 2024 15:56:11 +0200 Subject: [PATCH 01/23] Use ascending on MvMedian doubles --- .../multivalue/MvMedianDoubleEvaluator.java | 46 +++++++++++++++++++ .../function/scalar/multivalue/MvMedian.java | 4 +- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java index c3ea505a29e88..e3b539d8210aa 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java @@ -32,6 +32,9 @@ public String name() { */ @Override public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { @@ -60,6 +63,9 @@ public Block evalNullable(Block fieldVal) { */ @Override public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { @@ -79,6 +85,46 @@ public Block evalNotNullable(Block fieldVal) { } } + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. 
+ */ + private Block evalAscendingNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedian.Doubles work = new MvMedian.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + double result = MvMedian.ascending(v, first, valueCount); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedian.Doubles work = new MvMedian.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + double result = MvMedian.ascending(v, first, valueCount); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { private final EvalOperator.ExpressionEvaluator.Factory field; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java index e9e6899117805..42510a5685de3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java @@ 
-36,7 +36,7 @@ import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.unsignedLongToBigInteger; /** - * Reduce a multivalued field to a single valued field containing the average value. + * Reduce a multivalued field to a single valued field containing the median of the values. */ public class MvMedian extends AbstractMultivalueFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "MvMedian", MvMedian::new); @@ -106,7 +106,7 @@ static class Doubles { public int count; } - @MvEvaluator(extraName = "Double", finish = "finish") + @MvEvaluator(extraName = "Double", finish = "finish", ascending = "ascending") static void process(Doubles doubles, double v) { if (doubles.values.length < doubles.count + 1) { doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); From a1b958d495cbacaa5df12987b0d3506e98af37f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 19 Aug 2024 19:18:16 +0200 Subject: [PATCH 02/23] Initial function and tests --- .../compute/gen/MvEvaluatorImplementer.java | 35 +- ...edianAbsoluteDeviationDoubleEvaluator.java | 157 ++++++++ ...MvMedianAbsoluteDeviationIntEvaluator.java | 203 +++++++++++ ...vMedianAbsoluteDeviationLongEvaluator.java | 203 +++++++++++ ...bsoluteDeviationUnsignedLongEvaluator.java | 203 +++++++++++ .../AbstractMultivalueFunction.java | 1 + .../multivalue/MvMedianAbsoluteDeviation.java | 340 ++++++++++++++++++ .../MvMedianAbsoluteDeviationTests.java | 113 ++++++ 8 files changed, 1247 insertions(+), 8 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java create mode 100644 
x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java index 993b8363fb35f..d6e062facdbfd 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java @@ -107,7 +107,7 @@ public MvEvaluatorImplementer( this.resultType = TypeName.get(processFunction.getReturnType()); } this.singleValueFunction = SingleValueFunction.from(declarationType, singleValueMethodName, resultType, fieldType); - this.ascendingFunction = AscendingFunction.from(this, declarationType, ascendingMethodName); + this.ascendingFunction = AscendingFunction.from(this, declarationType, workType, ascendingMethodName); this.warnExceptions = warnExceptions; this.implementation = ClassName.get( elements.getPackageOf(declarationType).toString(), @@ -511,7 +511,7 @@ private void call(MethodSpec.Builder builder) { * Function handling blocks of ascending values. 
*/ private class AscendingFunction { - static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarationType, String name) { + static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarationType, TypeName workType, String name) { if (name.equals("")) { return null; } @@ -523,8 +523,9 @@ static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarati m -> m.getParameters().size() == 1 && m.getParameters().get(0).asType().getKind() == TypeKind.INT ); if (fn != null) { - return impl.new AscendingFunction(fn, false); + return impl.new AscendingFunction(fn, false, false); } + // Block mode without work parameter fn = findMethod( declarationType, new String[] { name }, @@ -532,17 +533,31 @@ static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarati && m.getParameters().get(1).asType().getKind() == TypeKind.INT && m.getParameters().get(2).asType().getKind() == TypeKind.INT ); - if (fn == null) { - throw new IllegalArgumentException("Couldn't find " + declarationType + "#" + name + "(block, int, int)"); + if (fn != null) { + return impl.new AscendingFunction(fn, true, false); } - return impl.new AscendingFunction(fn, true); + // Block mode with work parameter + fn = findMethod( + declarationType, + new String[] { name }, + m -> m.getParameters().size() == 4 + && TypeName.get(m.getParameters().get(0).asType()).equals(workType) + && m.getParameters().get(2).asType().getKind() == TypeKind.INT + && m.getParameters().get(3).asType().getKind() == TypeKind.INT + ); + if (fn != null) { + return impl.new AscendingFunction(fn, true, true); + } + throw new IllegalArgumentException("Couldn't find " + declarationType + "#" + name + "(block, int, int)"); } private final List invocationArgs = new ArrayList<>(); private final boolean blockMode; + private final boolean withWorkParameter; - private AscendingFunction(ExecutableElement fn, boolean blockMode) { + private AscendingFunction(ExecutableElement fn, boolean 
blockMode, boolean withWorkParameter) { this.blockMode = blockMode; + this.withWorkParameter = withWorkParameter; if (blockMode) { invocationArgs.add(resultType); } @@ -552,7 +567,11 @@ private AscendingFunction(ExecutableElement fn, boolean blockMode) { private void call(MethodSpec.Builder builder) { if (blockMode) { - builder.addStatement("$T result = $T.$L(v, first, valueCount)", invocationArgs.toArray()); + if (withWorkParameter) { + builder.addStatement("$T result = $T.$L(work, v, first, valueCount)", invocationArgs.toArray()); + } else { + builder.addStatement("$T result = $T.$L(v, first, valueCount)", invocationArgs.toArray()); + } } else { builder.addStatement("int idx = $T.$L(valueCount)", invocationArgs.toArray()); fetch(builder, "result", resultType, "first + idx", workType.equals(fieldType) ? "firstScratch" : "valueScratch"); diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java new file mode 100644 index 0000000000000..f576e02b142b1 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java @@ -0,0 +1,157 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationDoubleEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationDoubleEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + double value = v.getDouble(i); + MvMedianAbsoluteDeviation.process(work, value); + } + double result = MvMedianAbsoluteDeviation.finish(work); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing 
at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + double value = v.getDouble(i); + MvMedianAbsoluteDeviation.process(work, value); + } + double result = MvMedianAbsoluteDeviation.finish(work); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. 
+ */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationDoubleEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationDoubleEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java new file mode 100644 index 0000000000000..6d33c49459037 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationIntEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationIntEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. 
+ */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + int value = v.getInt(i); + MvMedianAbsoluteDeviation.process(work, value); + } + int result = MvMedianAbsoluteDeviation.finish(work); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. 
+ */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + int value = v.getInt(i); + MvMedianAbsoluteDeviation.process(work, value); + } + int result = MvMedianAbsoluteDeviation.finish(work); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. 
+ */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + int result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. 
+ */ + private Block evalAscendingNotNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + int result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationIntEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationIntEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java new file mode 100644 index 0000000000000..e7883d92708b7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationLongEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationLongEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.process(work, value); + } + long result = MvMedianAbsoluteDeviation.finish(work); + builder.appendLong(result); + } + return 
builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.process(work, value); + } + long result = MvMedianAbsoluteDeviation.finish(work); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. 
+ */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. 
+ */ + private Block evalAscendingNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationLongEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationLongEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java new file mode 100644 index 0000000000000..ef8781e1dc048 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationUnsignedLongEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationUnsignedLongEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.processUnsignedLong(work, value); + } + long result = MvMedianAbsoluteDeviation.finishUnsignedLong(work); + 
builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.processUnsignedLong(work, value); + } + long result = MvMedianAbsoluteDeviation.finishUnsignedLong(work); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. 
+ */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. 
+ */ + private Block evalAscendingNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascendingUnsignedLong(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascendingUnsignedLong(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationUnsignedLongEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationUnsignedLongEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java index 90810d282ca52..5ce59778c5aae 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java @@ -43,6 +43,7 @@ public static List getNamedWriteables() { MvLast.ENTRY, MvMax.ENTRY, MvMedian.ENTRY, + MvMedianAbsoluteDeviation.ENTRY, MvMin.ENTRY, MvPSeriesWeightedSum.ENTRY, MvSlice.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java new file mode 100644 index 0000000000000..183791be7c7d2 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -0,0 +1,340 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.apache.lucene.util.ArrayUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.ann.MvEvaluator; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.NumericUtils; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; + +import java.io.IOException; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.isRepresentable; +import static org.elasticsearch.xpack.esql.core.util.NumericUtils.unsignedLongSubtractExact; +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.bigIntegerToUnsignedLong; +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.unsignedLongToBigInteger; + +/** + * Reduce a multivalued field to a single valued field containing the median absolute deviation of the values. 
+ */ +public class MvMedianAbsoluteDeviation extends AbstractMultivalueFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MvMedianAbsoluteDeviation", + MvMedianAbsoluteDeviation::new + ); + + @FunctionInfo( + returnType = { "double", "integer", "long", "unsigned_long" }, + description = "Converts a multivalued field into a single valued field containing the median absolute deviation.", + examples = { + @Example(file = "math", tag = "mv_median_absolute_deviation"), + @Example( + description = "If the field has an even number of values, " + + "the medians will be calculated as the average of the middle two values. " + + "If the column is not floating point, the average rounds *down*.", + file = "math", + tag = "mv_median_absolute_deviation_round_down" + ) } + ) + public MvMedianAbsoluteDeviation( + Source source, + @Param( + name = "number", + type = { "double", "integer", "long", "unsigned_long" }, + description = "Multivalue expression." + ) Expression field + ) { + super(source, field); + } + + private MvMedianAbsoluteDeviation(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveFieldType() { + return isType(field(), t -> t.isNumeric() && isRepresentable(t), sourceText(), null, "numeric"); + } + + @Override + protected ExpressionEvaluator.Factory evaluator(ExpressionEvaluator.Factory fieldEval) { + return switch (PlannerUtils.toElementType(field().dataType())) { + case DOUBLE -> new MvMedianAbsoluteDeviationDoubleEvaluator.Factory(fieldEval); + case INT -> new MvMedianAbsoluteDeviationIntEvaluator.Factory(fieldEval); + case LONG -> field().dataType() == DataType.UNSIGNED_LONG + ? 
new MvMedianAbsoluteDeviationUnsignedLongEvaluator.Factory(fieldEval) + : new MvMedianAbsoluteDeviationLongEvaluator.Factory(fieldEval); + default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType()); + }; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new MvMedianAbsoluteDeviation(source(), newChildren.get(0)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, MvMedianAbsoluteDeviation::new, field()); + } + + static class Doubles { + public double[] values = new double[2]; + public int count; + } + + @MvEvaluator(extraName = "Double", finish = "finish", single = "single") + static void process(Doubles doubles, double v) { + if (doubles.values.length < doubles.count + 1) { + doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); + } + doubles.values[doubles.count++] = v; + } + + static double finish(Doubles doubles) { + double median = doubleMedianOf(doubles.values, doubles.count); + for (int i = 0; i < doubles.count; i++) { + double value = doubles.values[i]; + doubles.values[i] = value > median ? value - median : median - value; + } + double mad = doubleMedianOf(doubles.values, doubles.count); + doubles.count = 0; + return mad; + } + + static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { + if (doubles.values.length < count) { + doubles.values = ArrayUtil.grow(doubles.values, count); + } + int middle = firstValue + count / 2; + double median = count % 2 == 1 ? values.getDouble(middle) : (values.getDouble(middle - 1) + values.getDouble(middle)) / 2; + for (int i = 0; i < count; i++) { + double value = values.getDouble(firstValue + i); + doubles.values[i] = value > median ? 
value - median : median - value; + } + double mad = doubleMedianOf(doubles.values, count); + return mad; + } + + static double single(double value) { + return 0.; + } + + static double doubleMedianOf(double[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + return count % 2 == 1 ? values[middle] : (values[middle - 1] + values[middle]) / 2; + } + + static class Longs { + public long[] values = new long[2]; + public int count; + } + + @MvEvaluator(extraName = "Long", finish = "finish", ascending = "ascending", single = "single") + static void process(Longs longs, long v) { + if (longs.values.length < longs.count + 1) { + longs.values = ArrayUtil.grow(longs.values, longs.count + 1); + } + longs.values[longs.count++] = v; + } + + static long finish(Longs longs) { + long median = longMedianOf(longs.values, longs.count); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + longs.values[i] = value > median ? value - median : median - value; + } + long mad = longMedianOf(longs.values, longs.count); + longs.count = 0; + return mad; + } + + /** + * If the values are ascending pick the middle value or average the two middle values together. + */ + static long ascending(Longs longs, LongBlock values, int firstValue, int count) { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + int middle = firstValue + count / 2; + long median = count % 2 == 1 ? values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + longs.values[i] = value > median ? 
value - median : median - value; + } + long mad = longMedianOf(longs.values, count); + return mad; + } + + static long single(long value) { + return 0L; + } + + static long longMedianOf(long[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + return count % 2 == 1 ? values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); + } + + /** + * Average two {@code long}s without any overflow. + */ + static long avgWithoutOverflow(long a, long b) { + return (a & b) + ((a ^ b) >> 1); + } + + @MvEvaluator( + extraName = "UnsignedLong", + finish = "finishUnsignedLong", + ascending = "ascendingUnsignedLong", + single = "singleUnsignedLong" + ) + static void processUnsignedLong(Longs longs, long v) { + process(longs, v); + } + + static long finishUnsignedLong(Longs longs) { + long median = unsignedLongMedianOf(longs.values, longs.count); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + longs.values[i] = value > median + ? unsignedLongSubtractExact(value, median) + : unsignedLongSubtractExact(median, value); + } + long mad = unsignedLongMedianOf(longs.values, longs.count); + longs.count = 0; + return mad; + } + + /** + * If the values are ascending pick the middle value or average the two middle values together. + */ + static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) { + if (longs.values.length < longs.count + 1) { + longs.values = ArrayUtil.grow(longs.values, longs.count + 1); + } + int middle = firstValue + count / 2; + long median; + if (count % 2 == 1) { + median = values.getLong(middle); + } else { + BigInteger a = unsignedLongToBigInteger(values.getLong(middle - 1)); + BigInteger b = unsignedLongToBigInteger(values.getLong(middle)); + median = bigIntegerToUnsignedLong(a.add(b).shiftRight(1)); + } + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + longs.values[i] = value > median ? 
unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + } + long mad = unsignedLongMedianOf(longs.values, longs.count); + longs.count = 0; + return mad; + } + + static long singleUnsignedLong(long value) { + return NumericUtils.ZERO_AS_UNSIGNED_LONG; + } + + static long unsignedLongMedianOf(long[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + if (count % 2 == 1) { + return values[middle]; + } + BigInteger a = unsignedLongToBigInteger(values[middle - 1]); + BigInteger b = unsignedLongToBigInteger(values[middle]); + return bigIntegerToUnsignedLong(a.add(b).shiftRight(1)); + } + + static class Ints { + public int[] values = new int[2]; + public int count; + } + + @MvEvaluator(extraName = "Int", finish = "finish", ascending = "ascending", single = "single") + static void process(Ints ints, int v) { + if (ints.values.length < ints.count + 1) { + ints.values = ArrayUtil.grow(ints.values, ints.count + 1); + } + ints.values[ints.count++] = v; + } + + static int finish(Ints ints) { + int median = intMedianOf(ints.values, ints.count); + for (int i = 0; i < ints.count; i++) { + int value = ints.values[i]; + ints.values[i] = value > median ? value - median : median - value; + } + int mad = intMedianOf(ints.values, ints.count); + ints.count = 0; + return mad; + } + + /** + * If the values are ascending pick the middle value or average the two middle values together. + */ + static int ascending(Ints ints, IntBlock values, int firstValue, int count) { + if (ints.values.length < count) { + ints.values = ArrayUtil.grow(ints.values, count); + } + int middle = firstValue + count / 2; + int median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); + for (int i = 0; i < count; i++) { + int value = values.getInt(firstValue + i); + ints.values[i] = value > median ? 
value - median : median - value; + } + int mad = intMedianOf(ints.values, count); + return mad; + } + + static int single(int value) { + return 0; + } + + static int intMedianOf(int[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + return count % 2 == 1 ? values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); + } + + /** + * Average two {@code int}s together without overflow. + */ + static int avgWithoutOverflow(int a, int b) { + return (a & b) + ((a ^ b) >> 1); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java new file mode 100644 index 0000000000000..5e899fa770fc2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MvMedianAbsoluteDeviationTests extends AbstractMultivalueFunctionTestCase { + public MvMedianAbsoluteDeviationTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List cases = new ArrayList<>(); + doubles(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, valuesStream) -> { + var values = valuesStream.sorted().toArray(); + int middle = size / 2; + if (size % 2 == 1) { + double median = values[middle]; + var a = Arrays.stream(values).map(d -> Math.abs(d - median)).toArray(); + var b = Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().toArray(); + var c = Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle).findFirst().orElseThrow(); + return equalTo(Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle).findFirst().orElseThrow()); + } else { + double median = (values[middle - 1] + values[middle]) / 2; + return equalTo( + Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle - 1).limit(2).average().orElseThrow() + ); + } + }); + ints(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = 
calculateMedianAbsoluteDeviation(size, values.mapToObj(BigInteger::valueOf)); + return equalTo(mad.intValue()); + }); + longs(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = calculateMedianAbsoluteDeviation(size, values.mapToObj(BigInteger::valueOf)); + return equalTo(mad.longValue()); + }); + unsignedLongs(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = calculateMedianAbsoluteDeviation(size, values); + return equalTo(mad); + }); + + cases.add( + new TestCaseSupplier( + "mv_median_absolute_deviation(<1, 2>)", + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(1, 2), DataType.INTEGER, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(0) + ) + ) + ); + cases.add( + new TestCaseSupplier( + "mv_median_absolute_deviation(<-1, -2>)", + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(-1, -2), DataType.INTEGER, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(0) + ) + ) + ); + return parameterSuppliersFromTypedDataWithDefaultChecks(false, cases, (v, p) -> "numeric"); + } + + private static BigInteger calculateMedianAbsoluteDeviation(int size, Stream valuesStream) { + var values = valuesStream.sorted().toArray(BigInteger[]::new); + int middle = size / 2; + if (size % 2 == 1) { + var median = values[middle]; + return Arrays.stream(values).map(bi -> bi.subtract(median).abs()).sorted().skip(middle).findFirst().orElseThrow(); + } else { + var median = values[middle - 1].add(values[middle]).divide(BigInteger.valueOf(2)); + return Arrays.stream(values) + .map(bi -> bi.subtract(median).abs()) + .sorted() + .skip(middle - 1) + .limit(2) + .reduce(BigInteger.ZERO, BigInteger::add) + .divide(BigInteger.valueOf(2)); + } + } + + @Override + protected 
Expression build(Source source, Expression field) { + return new MvMedianAbsoluteDeviation(source, field); + } +} From ad318479cc898547a86fad17e31ae1d8af6868dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 20 Aug 2024 12:56:37 +0200 Subject: [PATCH 03/23] Fixed unsigned long ascending --- .../scalar/multivalue/MvMedianAbsoluteDeviation.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 183791be7c7d2..47b67d32736ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -242,8 +242,8 @@ static long finishUnsignedLong(Longs longs) { * If the values are ascending pick the middle value or average the two middle values together. */ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) { - if (longs.values.length < longs.count + 1) { - longs.values = ArrayUtil.grow(longs.values, longs.count + 1); + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); } int middle = firstValue + count / 2; long median; @@ -258,8 +258,7 @@ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, long value = values.getLong(firstValue + i); longs.values[i] = value > median ? 
unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); } - long mad = unsignedLongMedianOf(longs.values, longs.count); - longs.count = 0; + long mad = unsignedLongMedianOf(longs.values, count); return mad; } From adfab0505d6d2c6296427cd09fe3f1f654bcebf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 20 Aug 2024 18:17:53 +0200 Subject: [PATCH 04/23] Added overflow test cases, and double overflow fixes --- .../multivalue/MvMedianAbsoluteDeviation.java | 6 +- .../MvMedianAbsoluteDeviationTests.java | 138 +++++++++++++++--- 2 files changed, 122 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 47b67d32736ba..ccfebc3876e12 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -153,7 +153,7 @@ static double doubleMedianOf(double[] values, int count) { // TODO quickselect Arrays.sort(values, 0, count); int middle = count / 2; - return count % 2 == 1 ? values[middle] : (values[middle - 1] + values[middle]) / 2; + return count % 2 == 1 ? values[middle] : (values[middle - 1] / 2 + values[middle] / 2); } static class Longs { @@ -229,9 +229,7 @@ static long finishUnsignedLong(Longs longs) { long median = unsignedLongMedianOf(longs.values, longs.count); for (int i = 0; i < longs.count; i++) { long value = longs.values[i]; - longs.values[i] = value > median - ? unsignedLongSubtractExact(value, median) - : unsignedLongSubtractExact(median, value); + longs.values[i] = value > median ? 
unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); } long mad = unsignedLongMedianOf(longs.values, longs.count); longs.count = 0; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java index 5e899fa770fc2..ea6e4c903472c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -13,6 +13,8 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; +import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import java.math.BigInteger; @@ -37,9 +39,6 @@ public static Iterable parameters() { int middle = size / 2; if (size % 2 == 1) { double median = values[middle]; - var a = Arrays.stream(values).map(d -> Math.abs(d - median)).toArray(); - var b = Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().toArray(); - var c = Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle).findFirst().orElseThrow(); return equalTo(Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle).findFirst().orElseThrow()); } else { double median = (values[middle - 1] + values[middle]) / 2; @@ -61,31 +60,134 @@ public static Iterable parameters() { return equalTo(mad); }); - cases.add( + cases.addAll( + List.of( + // Simple cases + new TestCaseSupplier( + "mv_median_absolute_deviation(<1, 2>)", + 
List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(1, 2), DataType.INTEGER, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(0) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation(<-1, -2>)", + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(-1, -2), DataType.INTEGER, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(0) + ) + ) + ) + ); + + cases.addAll( + overflowCasesFor( + DataType.INTEGER, + Integer.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Integer.MAX_VALUE / 2, + Integer.MAX_VALUE / 2 + 1 + ) + ); + cases.addAll( + overflowCasesFor(DataType.LONG, Long.MAX_VALUE, Long.MIN_VALUE, Long.MAX_VALUE, Long.MAX_VALUE / 2L, Long.MAX_VALUE / 2L + 1) + ); + cases.addAll( + overflowCasesFor( + DataType.DOUBLE, + Double.MAX_VALUE, + -Double.MAX_VALUE, + Double.MAX_VALUE, + Double.MAX_VALUE / 2., + Double.MAX_VALUE / 2. + ) + ); + cases.addAll( + overflowCasesFor( + DataType.UNSIGNED_LONG, + NumericUtils.asLongUnsigned(NumericUtils.UNSIGNED_LONG_MAX), + NumericUtils.ZERO_AS_UNSIGNED_LONG, + NumericUtils.UNSIGNED_LONG_MAX.divide(BigInteger.valueOf(2)), + NumericUtils.UNSIGNED_LONG_MAX.divide(BigInteger.valueOf(2)), + BigInteger.ZERO + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(false, cases, (v, p) -> "numeric"); + } + + private static List overflowCasesFor( + DataType type, + Number max, + Number min, + Number maxMinMad, + Number maxZeroMad, + Number minZeroMad + ) { + var zeroExpected = DataTypeConverter.convert(0, type); + var zeroValue = type == DataType.UNSIGNED_LONG ? 
NumericUtils.ZERO_AS_UNSIGNED_LONG : zeroExpected; + + var typeName = type.name().toLowerCase(); + + return List.of( new TestCaseSupplier( - "mv_median_absolute_deviation(<1, 2>)", - List.of(DataType.INTEGER), + "mv_median_absolute_deviation()", + List.of(type), () -> new TestCaseSupplier.TestCase( - List.of(new TestCaseSupplier.TypedData(List.of(1, 2), DataType.INTEGER, "field")), + List.of(new TestCaseSupplier.TypedData(List.of(max, min), type, "field")), "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(0) + type, + equalTo(maxMinMad) ) - ) - ); - cases.add( + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(max, zeroValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(maxZeroMad) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(min, zeroValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(minZeroMad) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(max, max), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(zeroExpected) + ) + ), new TestCaseSupplier( - "mv_median_absolute_deviation(<-1, -2>)", - List.of(DataType.INTEGER), + "mv_median_absolute_deviation()", + List.of(type), () -> new TestCaseSupplier.TestCase( - List.of(new TestCaseSupplier.TypedData(List.of(-1, -2), DataType.INTEGER, "field")), + List.of(new TestCaseSupplier.TypedData(List.of(min, min), type, "field")), "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(0) + type, + equalTo(zeroExpected) ) ) ); - return 
parameterSuppliersFromTypedDataWithDefaultChecks(false, cases, (v, p) -> "numeric"); } private static BigInteger calculateMedianAbsoluteDeviation(int size, Stream valuesStream) { From 1c4c381c10ff701a945f149fd09b83335b94b2d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 21 Aug 2024 11:18:59 +0200 Subject: [PATCH 05/23] Added overflow checks and extra tests for them --- .../multivalue/MvMedianAbsoluteDeviation.java | 10 +++++--- .../MvMedianAbsoluteDeviationTests.java | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index ccfebc3876e12..43fa93d9ac96a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -125,6 +125,7 @@ static double finish(Doubles doubles) { for (int i = 0; i < doubles.count; i++) { double value = doubles.values[i]; doubles.values[i] = value > median ? value - median : median - value; + assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; } double mad = doubleMedianOf(doubles.values, doubles.count); doubles.count = 0; @@ -140,6 +141,7 @@ static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int for (int i = 0; i < count; i++) { double value = values.getDouble(firstValue + i); doubles.values[i] = value > median ? 
value - median : median - value; + assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; } double mad = doubleMedianOf(doubles.values, count); return mad; @@ -173,7 +175,7 @@ static long finish(Longs longs) { long median = longMedianOf(longs.values, longs.count); for (int i = 0; i < longs.count; i++) { long value = longs.values[i]; - longs.values[i] = value > median ? value - median : median - value; + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } long mad = longMedianOf(longs.values, longs.count); longs.count = 0; @@ -191,7 +193,7 @@ static long ascending(Longs longs, LongBlock values, int firstValue, int count) long median = count % 2 == 1 ? values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); for (int i = 0; i < count; i++) { long value = values.getLong(firstValue + i); - longs.values[i] = value > median ? value - median : median - value; + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } long mad = longMedianOf(longs.values, count); return mad; @@ -293,7 +295,7 @@ static int finish(Ints ints) { int median = intMedianOf(ints.values, ints.count); for (int i = 0; i < ints.count; i++) { int value = ints.values[i]; - ints.values[i] = value > median ? value - median : median - value; + ints.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } int mad = intMedianOf(ints.values, ints.count); ints.count = 0; @@ -311,7 +313,7 @@ static int ascending(Ints ints, IntBlock values, int firstValue, int count) { int median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); for (int i = 0; i < count; i++) { int value = values.getInt(firstValue + i); - ints.values[i] = value > median ? value - median : median - value; + ints.values[i] = value > median ? 
Math.subtractExact(value, median) : Math.subtractExact(median, value); } int mad = intMedianOf(ints.values, count); return mad; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java index ea6e4c903472c..44a8f34d8dc31 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -82,6 +82,18 @@ public static Iterable parameters() { DataType.INTEGER, equalTo(0) ) + ), + + // Custom double overflow, as "MAX_DOUBLE + 1000 == MAX_DOUBLE" + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(-Double.MAX_VALUE, Double.MAX_VALUE/4, Double.MAX_VALUE/4), DataType.DOUBLE, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(0.) + ) ) ) ); @@ -133,6 +145,9 @@ private static List overflowCasesFor( ) { var zeroExpected = DataTypeConverter.convert(0, type); var zeroValue = type == DataType.UNSIGNED_LONG ? NumericUtils.ZERO_AS_UNSIGNED_LONG : zeroExpected; + var oneThousandValue = type == DataType.UNSIGNED_LONG + ? 
NumericUtils.asLongUnsigned(BigInteger.valueOf(1000)) + : DataTypeConverter.convert(1000, type); var typeName = type.name().toLowerCase(); @@ -186,6 +201,16 @@ private static List overflowCasesFor( type, equalTo(zeroExpected) ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(min, oneThousandValue, oneThousandValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(zeroExpected) + ) ) ); } From 990ef6c3e4a22e86623940a4baf4f6e256c6d5ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 21 Aug 2024 11:55:58 +0200 Subject: [PATCH 06/23] Minor refactor --- .../scalar/multivalue/MvMedianAbsoluteDeviation.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 43fa93d9ac96a..513573b4ff0b0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -143,8 +143,7 @@ static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int doubles.values[i] = value > median ? 
value - median : median - value; assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; } - double mad = doubleMedianOf(doubles.values, count); - return mad; + return doubleMedianOf(doubles.values, count); } static double single(double value) { @@ -195,8 +194,7 @@ static long ascending(Longs longs, LongBlock values, int firstValue, int count) long value = values.getLong(firstValue + i); longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } - long mad = longMedianOf(longs.values, count); - return mad; + return longMedianOf(longs.values, count); } static long single(long value) { @@ -258,8 +256,7 @@ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, long value = values.getLong(firstValue + i); longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); } - long mad = unsignedLongMedianOf(longs.values, count); - return mad; + return unsignedLongMedianOf(longs.values, count); } static long singleUnsignedLong(long value) { @@ -315,8 +312,7 @@ static int ascending(Ints ints, IntBlock values, int firstValue, int count) { int value = values.getInt(firstValue + i); ints.values[i] = value > median ? 
Math.subtractExact(value, median) : Math.subtractExact(median, value); } - int mad = intMedianOf(ints.values, count); - return mad; + return intMedianOf(ints.values, count); } static int single(int value) { From c5cbd0ccc6da91a99d0abc82d60079e33615e4ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 21 Aug 2024 11:57:56 +0200 Subject: [PATCH 07/23] Update docs/changelog/112055.yaml --- docs/changelog/112055.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/112055.yaml diff --git a/docs/changelog/112055.yaml b/docs/changelog/112055.yaml new file mode 100644 index 0000000000000..cdf15b3b37468 --- /dev/null +++ b/docs/changelog/112055.yaml @@ -0,0 +1,6 @@ +pr: 112055 +summary: "ESQL: `mv_median_absolute_deviation` function" +area: ES|QL +type: feature +issues: + - 111590 From 25f05ccb273e94e130e41074ee5ecd6fdcdc4b17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 21 Aug 2024 12:26:12 +0200 Subject: [PATCH 08/23] Format --- .../scalar/multivalue/MvMedianAbsoluteDeviationTests.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java index 44a8f34d8dc31..7cb6e55594cfe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -89,7 +89,13 @@ public static Iterable parameters() { "mv_median_absolute_deviation()", List.of(DataType.DOUBLE), () -> new TestCaseSupplier.TestCase( - List.of(new TestCaseSupplier.TypedData(List.of(-Double.MAX_VALUE, 
Double.MAX_VALUE/4, Double.MAX_VALUE/4), DataType.DOUBLE, "field")), + List.of( + new TestCaseSupplier.TypedData( + List.of(-Double.MAX_VALUE, Double.MAX_VALUE / 4, Double.MAX_VALUE / 4), + DataType.DOUBLE, + "field" + ) + ), "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", DataType.DOUBLE, equalTo(0.) From 557c62ffafebc95268b6370b78e947c17a08f2ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 14:09:58 +0200 Subject: [PATCH 09/23] Fix overflows by using the next bigger number type for MAD (int -> long, long -> unsigned long) --- .../multivalue/MvMedianAbsoluteDeviation.java | 203 ++++++++++-------- 1 file changed, 112 insertions(+), 91 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 513573b4ff0b0..f1fd3a1595b5f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -107,59 +107,49 @@ protected NodeInfo info() { return NodeInfo.create(this, MvMedianAbsoluteDeviation::new, field()); } - static class Doubles { - public double[] values = new double[2]; + static class Longs { + public long[] values = new long[2]; public int count; } - @MvEvaluator(extraName = "Double", finish = "finish", single = "single") - static void process(Doubles doubles, double v) { - if (doubles.values.length < doubles.count + 1) { - doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); + @MvEvaluator(extraName = "Int", finish = "finishInts", ascending = "ascending", single = "single") + static void process(Longs longs, int v) { + if (longs.values.length < 
longs.count + 1) { + longs.values = ArrayUtil.grow(longs.values, longs.count + 1); } - doubles.values[doubles.count++] = v; + longs.values[longs.count++] = v; } - static double finish(Doubles doubles) { - double median = doubleMedianOf(doubles.values, doubles.count); - for (int i = 0; i < doubles.count; i++) { - double value = doubles.values[i]; - doubles.values[i] = value > median ? value - median : median - value; - assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; + static int finishInts(Longs longs) { + long median = longMedianOf(longs.values, longs.count); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + // We know they were ints, so we can calculate differences within a long + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } - double mad = doubleMedianOf(doubles.values, doubles.count); - doubles.count = 0; + int mad = Math.toIntExact(longMedianOf(longs.values, longs.count)); + longs.count = 0; return mad; } - static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { - if (doubles.values.length < count) { - doubles.values = ArrayUtil.grow(doubles.values, count); + /** + * If the values are ascending pick the middle value or average the two middle values together. + */ + static int ascending(Longs longs, IntBlock values, int firstValue, int count) { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); } int middle = firstValue + count / 2; - double median = count % 2 == 1 ? values.getDouble(middle) : (values.getDouble(middle - 1) + values.getDouble(middle)) / 2; + long median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); for (int i = 0; i < count; i++) { - double value = values.getDouble(firstValue + i); - doubles.values[i] = value > median ? 
value - median : median - value; - assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; + long value = values.getInt(firstValue + i); + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); } - return doubleMedianOf(doubles.values, count); - } - - static double single(double value) { - return 0.; - } - - static double doubleMedianOf(double[] values, int count) { - // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - return count % 2 == 1 ? values[middle] : (values[middle - 1] / 2 + values[middle] / 2); + return Math.toIntExact(longMedianOf(longs.values, count)); } - static class Longs { - public long[] values = new long[2]; - public int count; + static int single(int value) { + return 0; } @MvEvaluator(extraName = "Long", finish = "finish", ascending = "ascending", single = "single") @@ -174,9 +164,10 @@ static long finish(Longs longs) { long median = longMedianOf(longs.values, longs.count); for (int i = 0; i < longs.count; i++) { long value = longs.values[i]; - longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); } - long mad = longMedianOf(longs.values, longs.count); + long mad = NumericUtils.asLongUnsigned(unsignedLongMedianOf(longs.values, longs.count)); longs.count = 0; return mad; } @@ -192,9 +183,10 @@ static long ascending(Longs longs, LongBlock values, int firstValue, int count) long median = count % 2 == 1 ? values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); for (int i = 0; i < count; i++) { long value = values.getLong(firstValue + i); - longs.values[i] = value > median ? 
Math.subtractExact(value, median) : Math.subtractExact(median, value); + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); } - return longMedianOf(longs.values, count); + return NumericUtils.asLongUnsigned(unsignedLongMedianOf(longs.values, count)); } static long single(long value) { @@ -208,11 +200,55 @@ static long longMedianOf(long[] values, int count) { return count % 2 == 1 ? values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); } - /** - * Average two {@code long}s without any overflow. - */ - static long avgWithoutOverflow(long a, long b) { - return (a & b) + ((a ^ b) >> 1); + static class Doubles { + public double[] values = new double[2]; + public int count; + } + + @MvEvaluator(extraName = "Double", finish = "finish", single = "single") + static void process(Doubles doubles, double v) { + if (doubles.values.length < doubles.count + 1) { + doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); + } + doubles.values[doubles.count++] = v; + } + + static double finish(Doubles doubles) { + double median = doubleMedianOf(doubles.values, doubles.count); + for (int i = 0; i < doubles.count; i++) { + double value = doubles.values[i]; + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. + doubles.values[i] = value > median ? value - median : median - value; + } + double mad = doubleMedianOf(doubles.values, doubles.count); + doubles.count = 0; + return mad; + } + + static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { + if (doubles.values.length < count) { + doubles.values = ArrayUtil.grow(doubles.values, count); + } + int middle = firstValue + count / 2; + double median = count % 2 == 1 ? 
values.getDouble(middle) : (values.getDouble(middle - 1) + values.getDouble(middle)) / 2; + for (int i = 0; i < count; i++) { + double value = values.getDouble(firstValue + i); + doubles.values[i] = value > median ? value - median : median - value; + assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; + } + return doubleMedianOf(doubles.values, count); + } + + static double single(double value) { + return 0.; + } + + static double doubleMedianOf(double[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + return count % 2 == 1 ? values[middle] : (values[middle - 1] / 2 + values[middle] / 2); } @MvEvaluator( @@ -275,61 +311,46 @@ static long unsignedLongMedianOf(long[] values, int count) { return bigIntegerToUnsignedLong(a.add(b).shiftRight(1)); } - static class Ints { - public int[] values = new int[2]; - public int count; - } + // Utility methods - @MvEvaluator(extraName = "Int", finish = "finish", ascending = "ascending", single = "single") - static void process(Ints ints, int v) { - if (ints.values.length < ints.count + 1) { - ints.values = ArrayUtil.grow(ints.values, ints.count + 1); - } - ints.values[ints.count++] = v; + /** + * Average two {@code int}s together without overflow. + */ + static int avgWithoutOverflow(int a, int b) { + var value = (a & b) + ((a ^ b) >> 1); + // This method rounds negative values down instead of towards zero, like (a + b) / 2 would. + // Here we rectify up if the average is negative and the two values have different parities. + return value < 0 && ((a & 1) ^ (b & 1)) == 1 ? value + 1 : value; } - static int finish(Ints ints) { - int median = intMedianOf(ints.values, ints.count); - for (int i = 0; i < ints.count; i++) { - int value = ints.values[i]; - ints.values[i] = value > median ? 
Math.subtractExact(value, median) : Math.subtractExact(median, value); - } - int mad = intMedianOf(ints.values, ints.count); - ints.count = 0; - return mad; + /** + * Average two {@code long}s without any overflow. + */ + static long avgWithoutOverflow(long a, long b) { + var value = (a & b) + ((a ^ b) >> 1); + // This method rounds negative values down instead of towards zero, like (a + b) / 2 would. + // Here we rectify up if the average is negative and the two values have different parities. + return value < 0 && ((a & 1) ^ (b & 1)) == 1 ? value + 1 : value; } /** - * If the values are ascending pick the middle value or average the two middle values together. + * Returns the difference between two signed long values as an unsigned long. */ - static int ascending(Ints ints, IntBlock values, int firstValue, int count) { - if (ints.values.length < count) { - ints.values = ArrayUtil.grow(ints.values, count); - } - int middle = firstValue + count / 2; - int median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); - for (int i = 0; i < count; i++) { - int value = values.getInt(firstValue + i); - ints.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + static long unsignedDifference(long a, long b) { + if (a >= b) { + if (a < 0 || b >= 0) { + return NumericUtils.asLongUnsigned(a - b); + } + + return NumericUtils.unsignedLongSubtractExact(a, b); } - return intMedianOf(ints.values, count); - } - static int single(int value) { - return 0; - } + // Same operations, but inverted - static int intMedianOf(int[] values, int count) { - // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - return count % 2 == 1 ? values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); - } + if (b < 0 || a >= 0) { + return NumericUtils.asLongUnsigned(b - a); + } - /** - * Average two {@code int}s together without overflow. 
- */ - static int avgWithoutOverflow(int a, int b) { - return (a & b) + ((a ^ b) >> 1); + return NumericUtils.unsignedLongSubtractExact(b, a); } } From 60b698bdbd43d8aba72ec5990e225beececa672b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 14:31:50 +0200 Subject: [PATCH 10/23] Ensure exact conversions in longs --- .../xpack/esql/core/util/NumericUtils.java | 12 ++++++++++++ .../multivalue/MvMedianAbsoluteDeviation.java | 16 ++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java index 3bff45db5023c..1f960f79c7901 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java @@ -88,6 +88,18 @@ public static Number unsignedLongAsNumber(long l) { return l < 0 ? twosComplement(l) : LONG_MAX_PLUS_ONE_AS_BIGINTEGER.add(BigInteger.valueOf(l)); } + /** + * Converts an unsigned long value "encoded" into a (signed) long. + * In case of overflow, an ArithmeticException is thrown. + */ + public static long unsignedLongAsLongExact(long l) { + if (l < 0) { + return twosComplement(l); + } + + throw new ArithmeticException(UNSIGNED_LONG_OVERFLOW); + } + public static BigInteger unsignedLongAsBigInteger(long l) { return l < 0 ? 
BigInteger.valueOf(twosComplement(l)) : LONG_MAX_PLUS_ONE_AS_BIGINTEGER.add(BigInteger.valueOf(l)); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index f1fd3a1595b5f..b61251018ef56 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -167,7 +167,7 @@ static long finish(Longs longs) { // From here, this array contains unsigned longs longs.values[i] = unsignedDifference(value, median); } - long mad = NumericUtils.asLongUnsigned(unsignedLongMedianOf(longs.values, longs.count)); + long mad = unsignedLongMedianOfAsLong(longs.values, longs.count); longs.count = 0; return mad; } @@ -186,7 +186,7 @@ static long ascending(Longs longs, LongBlock values, int firstValue, int count) // From here, this array contains unsigned longs longs.values[i] = unsignedDifference(value, median); } - return NumericUtils.asLongUnsigned(unsignedLongMedianOf(longs.values, count)); + return unsignedLongMedianOfAsLong(longs.values, count); } static long single(long value) { @@ -200,6 +200,18 @@ static long longMedianOf(long[] values, int count) { return count % 2 == 1 ? 
values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); } + static long unsignedLongMedianOfAsLong(long[] values, int count) { + // TODO quickselect + Arrays.sort(values, 0, count); + int middle = count / 2; + if (count % 2 == 1) { + return NumericUtils.unsignedLongAsLongExact(values[middle]); + } + BigInteger a = unsignedLongToBigInteger(values[middle - 1]); + BigInteger b = unsignedLongToBigInteger(values[middle]); + return a.add(b).shiftRight(1).longValueExact(); + } + static class Doubles { public double[] values = new double[2]; public int count; From 880f3108cd4200f272e3a4f8d6f336c4e2f00fb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 15:52:54 +0200 Subject: [PATCH 11/23] Avoid using BigIntegers for unsigned longs --- .../multivalue/MvMedianAbsoluteDeviation.java | 34 ++---- .../MvMedianAbsoluteDeviationTests.java | 102 ++++++++++-------- 2 files changed, 70 insertions(+), 66 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index b61251018ef56..fce5255734366 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -27,15 +27,12 @@ import org.elasticsearch.xpack.esql.planner.PlannerUtils; import java.io.IOException; -import java.math.BigInteger; import java.util.Arrays; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.isRepresentable; import static 
org.elasticsearch.xpack.esql.core.util.NumericUtils.unsignedLongSubtractExact; -import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.bigIntegerToUnsignedLong; -import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.unsignedLongToBigInteger; /** * Reduce a multivalued field to a single valued field containing the median absolute deviation of the values. @@ -167,7 +164,7 @@ static long finish(Longs longs) { // From here, this array contains unsigned longs longs.values[i] = unsignedDifference(value, median); } - long mad = unsignedLongMedianOfAsLong(longs.values, longs.count); + long mad = NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs.values, longs.count)); longs.count = 0; return mad; } @@ -186,7 +183,7 @@ static long ascending(Longs longs, LongBlock values, int firstValue, int count) // From here, this array contains unsigned longs longs.values[i] = unsignedDifference(value, median); } - return unsignedLongMedianOfAsLong(longs.values, count); + return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs.values, count)); } static long single(long value) { @@ -200,18 +197,6 @@ static long longMedianOf(long[] values, int count) { return count % 2 == 1 ? 
values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); } - static long unsignedLongMedianOfAsLong(long[] values, int count) { - // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - if (count % 2 == 1) { - return NumericUtils.unsignedLongAsLongExact(values[middle]); - } - BigInteger a = unsignedLongToBigInteger(values[middle - 1]); - BigInteger b = unsignedLongToBigInteger(values[middle]); - return a.add(b).shiftRight(1).longValueExact(); - } - static class Doubles { public double[] values = new double[2]; public int count; @@ -296,9 +281,7 @@ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, if (count % 2 == 1) { median = values.getLong(middle); } else { - BigInteger a = unsignedLongToBigInteger(values.getLong(middle - 1)); - BigInteger b = unsignedLongToBigInteger(values.getLong(middle)); - median = bigIntegerToUnsignedLong(a.add(b).shiftRight(1)); + median = unsignedLongAvgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); } for (int i = 0; i < count; i++) { long value = values.getLong(firstValue + i); @@ -318,9 +301,7 @@ static long unsignedLongMedianOf(long[] values, int count) { if (count % 2 == 1) { return values[middle]; } - BigInteger a = unsignedLongToBigInteger(values[middle - 1]); - BigInteger b = unsignedLongToBigInteger(values[middle]); - return bigIntegerToUnsignedLong(a.add(b).shiftRight(1)); + return unsignedLongAvgWithoutOverflow(values[middle - 1], values[middle]); } // Utility methods @@ -345,6 +326,13 @@ static long avgWithoutOverflow(long a, long b) { return value < 0 && ((a & 1) ^ (b & 1)) == 1 ? value + 1 : value; } + /** + * Average two {@code unsigned long}s without any overflow. + */ + static long unsignedLongAvgWithoutOverflow(long a, long b) { + return (a >> 1) + (b >> 1) + (a & b & 1); + } + /** * Returns the difference between two signed long values as an unsigned long. 
*/ diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java index 7cb6e55594cfe..1209193c02caf 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -60,50 +60,13 @@ public static Iterable parameters() { return equalTo(mad); }); - cases.addAll( - List.of( - // Simple cases - new TestCaseSupplier( - "mv_median_absolute_deviation(<1, 2>)", - List.of(DataType.INTEGER), - () -> new TestCaseSupplier.TestCase( - List.of(new TestCaseSupplier.TypedData(List.of(1, 2), DataType.INTEGER, "field")), - "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(0) - ) - ), - new TestCaseSupplier( - "mv_median_absolute_deviation(<-1, -2>)", - List.of(DataType.INTEGER), - () -> new TestCaseSupplier.TestCase( - List.of(new TestCaseSupplier.TypedData(List.of(-1, -2), DataType.INTEGER, "field")), - "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", - DataType.INTEGER, - equalTo(0) - ) - ), - - // Custom double overflow, as "MAX_DOUBLE + 1000 == MAX_DOUBLE" - new TestCaseSupplier( - "mv_median_absolute_deviation()", - List.of(DataType.DOUBLE), - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData( - List.of(-Double.MAX_VALUE, Double.MAX_VALUE / 4, Double.MAX_VALUE / 4), - DataType.DOUBLE, - "field" - ) - ), - "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", - DataType.DOUBLE, - equalTo(0.) 
- ) - ) - ) - ); + // Simple cases + cases.addAll(makeCases(List.of(1, 2, 5), 1, true)); + cases.addAll(makeCases(List.of(1, 2), 0, false)); + cases.addAll(makeCases(List.of(-1, -2), 0, false)); + cases.addAll(makeCases(List.of(0, 2, 5, 6), 2, false)); + // Overflow cases cases.addAll( overflowCasesFor( DataType.INTEGER, @@ -138,9 +101,62 @@ public static Iterable parameters() { ) ); + // Custom double overflow. Can't be checked in the generic overflow cases, as "MAX_DOUBLE + 1000 == MAX_DOUBLE" + cases.add( + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData( + List.of(-Double.MAX_VALUE, Double.MAX_VALUE / 4, Double.MAX_VALUE / 4), + DataType.DOUBLE, + "field" + ) + ), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(0.) + ) + ) + ); + return parameterSuppliersFromTypedDataWithDefaultChecks(false, cases, (v, p) -> "numeric"); } + /** + * Makes cases for the given data, for each type + */ + private static List makeCases(List data, Number expectedResult, boolean withDoubles) { + var types = new ArrayList<>(List.of(DataType.INTEGER, DataType.LONG)); + if (withDoubles) { + types.add(DataType.DOUBLE); + } + if (data.stream().noneMatch(d -> d.doubleValue() < 0)) { + types.add(DataType.UNSIGNED_LONG); + } + return types.stream().map(type -> { + var convertedData = data.stream().map(d -> { + var convertedValue = DataTypeConverter.convert(d, type); + if (convertedValue instanceof BigInteger bi) { + return NumericUtils.asLongUnsigned(bi); + } + return convertedValue; + }).toList(); + + return new TestCaseSupplier( + "<" + convertedData + "> (" + type + ")", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(convertedData, type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(DataTypeConverter.convert(expectedResult, type)) + ) + ); + 
}).toList(); + } + private static List overflowCasesFor( DataType type, Number max, From 0da4f145d65ef5748ea91ad53827f3e06f459774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 16:18:57 +0200 Subject: [PATCH 12/23] Added function to registry, and updated failing tests --- .../src/main/resources/meta.csv-spec | 6 +++- ...MvMedianAbsoluteDeviationIntEvaluator.java | 16 +++++----- .../function/EsqlFunctionRegistry.java | 2 ++ ...anAbsoluteDeviationSerializationTests.java | 32 +++++++++++++++++++ .../MvMedianAbsoluteDeviationTests.java | 3 +- 5 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index f1f66a9cb990c..cc45a42a8e64c 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -53,6 +53,7 @@ double e() "boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version mv_last(field:boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version)" "boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version mv_max(field:boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version)" "double|integer|long|unsigned_long mv_median(number:double|integer|long|unsigned_long)" +"double|integer|long|unsigned_long mv_median_absolute_deviation(number:double|integer|long|unsigned_long)" "boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version 
mv_min(field:boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version)" "double|integer|long mv_percentile(number:double|integer|long, percentile:double|integer|long)" "double mv_pseries_weighted_sum(number:double, p:double)" @@ -177,6 +178,7 @@ mv_first |field |"boolean|cartesian_point|car mv_last |field |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_max |field |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_median |number |"double|integer|long|unsigned_long" |Multivalue expression. +mv_median_abso|number |"double|integer|long|unsigned_long" |Multivalue expression. mv_min |field |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_percentile |[number, percentile] |["double|integer|long", "double|integer|long"] |[Multivalue expression., The percentile to calculate. Must be a number between 0 and 100. Numbers out of range will return a null instead.] mv_pseries_wei|[number, p] |[double, double] |[Multivalue expression., It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum.] @@ -301,6 +303,7 @@ mv_first |Converts a multivalued expression into a single valued column con mv_last |Converts a multivalue expression into a single valued column containing the last value. This is most useful when reading from a function that emits multivalued columns in a known order like <>. The order that <> are read from underlying storage is not guaranteed. It is *frequently* ascending, but don't rely on that. If you need the maximum value use <> instead of `MV_LAST`. `MV_MAX` has optimizations for sorted values so there isn't a performance benefit to `MV_LAST`. 
mv_max |Converts a multivalued expression into a single valued column containing the maximum value. mv_median |Converts a multivalued field into a single valued field containing the median value. +mv_median_abso|Converts a multivalued field into a single valued field containing the median absolute deviation. mv_min |Converts a multivalued expression into a single valued column containing the minimum value. mv_percentile |Converts a multivalued field into a single valued field containing the value at which a certain percentage of observed values occur. mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. @@ -427,6 +430,7 @@ mv_first |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|g mv_last |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version"|false |false |false mv_max |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false mv_median |"double|integer|long|unsigned_long" |false |false |false +mv_median_abso|"double|integer|long|unsigned_long" |false |false |false mv_min |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false mv_percentile |"double|integer|long" |[false, false] |false |false mv_pseries_wei|"double" |[false, false] |false |false @@ -508,5 +512,5 @@ countFunctions#[skip:-8.15.99] meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; a:long | b:long | c:long -115 | 115 | 115 +116 | 116 | 116 ; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java index 
6d33c49459037..76013ca1115db 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java @@ -38,7 +38,7 @@ public Block evalNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { int valueCount = v.getValueCount(p); if (valueCount == 0) { @@ -57,7 +57,7 @@ public Block evalNullable(Block fieldVal) { int value = v.getInt(i); MvMedianAbsoluteDeviation.process(work, value); } - int result = MvMedianAbsoluteDeviation.finish(work); + int result = MvMedianAbsoluteDeviation.finishInts(work); builder.appendInt(result); } return builder.build(); @@ -75,7 +75,7 @@ public Block evalNotNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { int valueCount = v.getValueCount(p); int first = v.getFirstValueIndex(p); @@ -90,7 +90,7 @@ public Block evalNotNullable(Block fieldVal) { int value = v.getInt(i); MvMedianAbsoluteDeviation.process(work, value); } - int result = MvMedianAbsoluteDeviation.finish(work); + int result = MvMedianAbsoluteDeviation.finishInts(work); builder.appendInt(result); } return builder.build().asBlock(); @@ -105,7 
+105,7 @@ public Block evalSingleValuedNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { int valueCount = v.getValueCount(p); if (valueCount == 0) { @@ -130,7 +130,7 @@ public Block evalSingleValuedNotNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { int valueCount = v.getValueCount(p); assert valueCount == 1; @@ -150,7 +150,7 @@ private Block evalAscendingNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { int valueCount = v.getValueCount(p); if (valueCount == 0) { @@ -172,7 +172,7 @@ private Block evalAscendingNotNullable(Block fieldVal) { IntBlock v = (IntBlock) fieldVal; int positionCount = v.getPositionCount(); try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { - MvMedianAbsoluteDeviation.Ints work = new MvMedianAbsoluteDeviation.Ints(); + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); for (int p = 0; p < positionCount; p++) { 
int valueCount = v.getValueCount(p); int first = v.getFirstValueIndex(p); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 0d50623fe77eb..99794dd5875f8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -94,6 +94,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvLast; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviation; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPSeriesWeightedSum; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPercentile; @@ -363,6 +364,7 @@ private FunctionDefinition[][] functions() { def(MvLast.class, MvLast::new, "mv_last"), def(MvMax.class, MvMax::new, "mv_max"), def(MvMedian.class, MvMedian::new, "mv_median"), + def(MvMedianAbsoluteDeviation.class, MvMedianAbsoluteDeviation::new, "mv_median_absolute_deviation"), def(MvMin.class, MvMin::new, "mv_min"), def(MvPercentile.class, MvPercentile::new, "mv_percentile"), def(MvPSeriesWeightedSum.class, MvPSeriesWeightedSum::new, "mv_pseries_weighted_sum"), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java new file mode 100644 index 0000000000000..6a63c38c924d9 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; + +public class MvMedianAbsoluteDeviationSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected MvMedianAbsoluteDeviation createTestInstance() { + return new MvMedianAbsoluteDeviation(randomSource(), randomChild()); + } + + @Override + protected MvMedianAbsoluteDeviation mutateInstance(MvMedianAbsoluteDeviation instance) throws IOException { + return new MvMedianAbsoluteDeviation( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) + ); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java index 1209193c02caf..b041faf6510a1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.function.Supplier; import java.util.stream.Stream; @@ -171,7 +172,7 @@ private static List overflowCasesFor( ? NumericUtils.asLongUnsigned(BigInteger.valueOf(1000)) : DataTypeConverter.convert(1000, type); - var typeName = type.name().toLowerCase(); + var typeName = type.name().toLowerCase(Locale.ROOT); return List.of( new TestCaseSupplier( From 5254feef84086f5de8133f733c6b30c922ff089b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 16:58:55 +0200 Subject: [PATCH 13/23] Added CSV tests and docs --- .../mv_median_absolute_deviation.asciidoc | 7 +++ .../mv_median_absolute_deviation.asciidoc | 13 ++++ .../mv_median_absolute_deviation.json | 60 +++++++++++++++++++ .../docs/mv_median_absolute_deviation.md | 12 ++++ .../mv_median_absolute_deviation.asciidoc | 15 +++++ .../esql/functions/mv-functions.asciidoc | 2 + .../mv_median_absolute_deviation.asciidoc | 6 ++ .../mv_median_absolute_deviation.svg | 1 + .../mv_median_absolute_deviation.asciidoc | 12 ++++ .../mv_median_absolute_deviation.csv-spec | 49 +++++++++++++++ .../xpack/esql/action/EsqlCapabilities.java | 5 ++ .../multivalue/MvMedianAbsoluteDeviation.java | 13 ++-- 12 files changed, 186 insertions(+), 9 deletions(-) create mode 100644 docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc create mode 100644 docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc create mode 100644 docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json create mode 100644 docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md create mode 100644 docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc create 
mode 100644 docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc create mode 100644 docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg create mode 100644 docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec diff --git a/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..a681670c545a9 --- /dev/null +++ b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,7 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +Converts a multivalued field into a single valued field containing the median absolute deviation. + +NOTE: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0. diff --git a/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..b36bc18a80174 --- /dev/null +++ b/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +*Example* + +[source.merge.styled,esql] +---- +include::{esql-specs}/mv_median_absolute_deviation.csv-spec[tag=example] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/mv_median_absolute_deviation.csv-spec[tag=example-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json new file mode 100644 index 0000000000000..a361722d2c44a --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json @@ -0,0 +1,60 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "eval", + "name" : "mv_median_absolute_deviation", + "description" : "Converts a multivalued field into a single valued field containing the median absolute deviation.", + "note" : "If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "long" + }, + { + "params" : [ + { + "name" : "number", + "type" : "unsigned_long", + "optional" : false, + "description" : "Multivalue expression." 
+ } + ], + "variadic" : false, + "returnType" : "unsigned_long" + } + ], + "examples" : [ + "ROW values = [0, 2, 5, 6]\n| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values)" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md new file mode 100644 index 0000000000000..5622cf7af6e50 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md @@ -0,0 +1,12 @@ + + +### MV_MEDIAN_ABSOLUTE_DEVIATION +Converts a multivalued field into a single valued field containing the median absolute deviation. + +``` +ROW values = [0, 2, 5, 6] +| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values) +``` +Note: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0. diff --git a/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..b594d589e6108 --- /dev/null +++ b/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +[discrete] +[[esql-mv_median_absolute_deviation]] +=== `MV_MEDIAN_ABSOLUTE_DEVIATION` + +*Syntax* + +[.text-center] +image::esql/functions/signature/mv_median_absolute_deviation.svg[Embedded,opts=inline] + +include::../parameters/mv_median_absolute_deviation.asciidoc[] +include::../description/mv_median_absolute_deviation.asciidoc[] +include::../types/mv_median_absolute_deviation.asciidoc[] +include::../examples/mv_median_absolute_deviation.asciidoc[] diff --git a/docs/reference/esql/functions/mv-functions.asciidoc b/docs/reference/esql/functions/mv-functions.asciidoc index bd5f14cdd3557..4093e44c16911 100644 --- a/docs/reference/esql/functions/mv-functions.asciidoc +++ b/docs/reference/esql/functions/mv-functions.asciidoc @@ -17,6 +17,7 @@ * <> * <> * <> +* <> * <> * <> * <> @@ -34,6 +35,7 @@ include::layout/mv_first.asciidoc[] include::layout/mv_last.asciidoc[] include::layout/mv_max.asciidoc[] include::layout/mv_median.asciidoc[] +include::layout/mv_median_absolute_deviation.asciidoc[] include::layout/mv_min.asciidoc[] include::layout/mv_pseries_weighted_sum.asciidoc[] include::layout/mv_slice.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..47859c7e2b320 --- /dev/null +++ b/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: +Multivalue expression. 
diff --git a/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg b/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg new file mode 100644 index 0000000000000..7d8a131a91015 --- /dev/null +++ b/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg @@ -0,0 +1 @@ +MV_MEDIAN_ABSOLUTE_DEVIATION(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..d81bbf36ae3fe --- /dev/null +++ b/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,12 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | integer +long | long +unsigned_long | unsigned_long +|=== diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec new file mode 100644 index 0000000000000..35988b80b9da3 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec @@ -0,0 +1,49 @@ +example +required_capability: fn_mv_median_absolute_deviation + +// tag::example[] +ROW values = [0, 2, 5, 6] +| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values) +// end::example[] +; + +// tag::example-result[] +values:integer | median_absolute_deviation:integer | median:integer +[0, 2, 5, 6] | 2 | 3 +// end::example-result[] +; + + +multipleExpressions +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6] +| EVAL + MV_MEDIAN_ABSOLUTE_DEVIATION(x), + a = MV_MEDIAN_ABSOLUTE_DEVIATION(x), + b = 
MV_MEDIAN_ABSOLUTE_DEVIATION(TO_DOUBLE([2, 5])), + c = MV_MEDIAN_ABSOLUTE_DEVIATION(CASE(true, x, [0, 1])) +; + +x:integer | MV_MEDIAN_ABSOLUTE_DEVIATION(x):integer | a:integer | b:double | c:integer +[0, 2, 5, 6] | 2 | 2 | 1.5 | 2 +; + +nullsAndFolds +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6], single = 300 +| EVAL evalNull = null / 2, evalValue = 31 + 1 +| LIMIT 1 +| EVAL + a = MV_MEDIAN_ABSOLUTE_DEVIATION(x), + b = MV_MEDIAN_ABSOLUTE_DEVIATION(single), + c = MV_MEDIAN_ABSOLUTE_DEVIATION(null), + d = MV_MEDIAN_ABSOLUTE_DEVIATION(evalNull), + e = MV_MEDIAN_ABSOLUTE_DEVIATION(evalValue) +| KEEP a, b, c, d, e +; + +a:integer | b:integer | c:null | d:integer | e:integer +2 | 0 | null | null | 0 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 81b2ba71b8808..bd6665982857b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -37,6 +37,11 @@ public enum Cap { */ FN_MV_APPEND, + /** + * Support for {@code MV_MEDIAN_ABSOLUTE_DEVIATION} function. + */ + FN_MV_MEDIAN_ABSOLUTE_DEVIATION, + /** * Support for {@code MV_PERCENTILE} function. 
*/ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index fce5255734366..007c317579819 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -47,15 +47,10 @@ public class MvMedianAbsoluteDeviation extends AbstractMultivalueFunction { @FunctionInfo( returnType = { "double", "integer", "long", "unsigned_long" }, description = "Converts a multivalued field into a single valued field containing the median absolute deviation.", - examples = { - @Example(file = "math", tag = "mv_median_absolute_deviation"), - @Example( - description = "If the field has an even number of values, " - + "the medians will be calculated as the average of the middle two values. " - + "If the column is not floating point, the average rounds *down*.", - file = "math", - tag = "mv_median_absolute_deviation_round_down" - ) } + note = "If the field has an even number of values, " + + "the medians will be calculated as the average of the middle two values. 
" + + "If the column is not floating point, the averages round towards 0.", + examples = @Example(file = "mv_median_absolute_deviation", tag = "example") ) public MvMedianAbsoluteDeviation( Source source, From f4c312c3d4022428cfa8de33bff5ffd44089ad8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 17:17:55 +0200 Subject: [PATCH 14/23] Added csv tests for all types and using an index --- .../mv_median_absolute_deviation.csv-spec | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec index 35988b80b9da3..a8dbd1d384b80 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec @@ -13,6 +13,39 @@ values:integer | median_absolute_deviation:integer | median:integer // end::example-result[] ; +fromIndex +required_capability: fn_mv_median_absolute_deviation + +FROM employees +| WHERE emp_no <= 10002 +| SORT emp_no ASC +| EVAL + int = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.int), + long = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.long), + double = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change) +| KEEP emp_no, int, long, double +; + +emp_no:integer | int:integer | long:long | double:double +10001 | 0 | 0 | 0 +10002 | 9 | 9 | 9.2 +; + +allTypes +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6] +| EVAL + int = MV_MEDIAN_ABSOLUTE_DEVIATION(x::INTEGER), + long = MV_MEDIAN_ABSOLUTE_DEVIATION(x::LONG), + double = MV_MEDIAN_ABSOLUTE_DEVIATION(x::DOUBLE), + ul = MV_MEDIAN_ABSOLUTE_DEVIATION(x::UNSIGNED_LONG) +| KEEP int, long, double, ul +; + +int:integer | long:long | double:double | ul:unsigned_long +2 | 2 | 2 | 2 +; multipleExpressions required_capability: 
fn_mv_median_absolute_deviation From 80966541a896e0a8e32516f21ee6db6d87165740 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 17:23:47 +0200 Subject: [PATCH 15/23] Assert doubles median is finite --- .../scalar/multivalue/MvMedianAbsoluteDeviation.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 007c317579819..685c3d01c4d24 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -226,8 +226,9 @@ static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int double median = count % 2 == 1 ? values.getDouble(middle) : (values.getDouble(middle - 1) + values.getDouble(middle)) / 2; for (int i = 0; i < count; i++) { double value = values.getDouble(firstValue + i); + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. doubles.values[i] = value > median ? value - median : median - value; - assert Double.isFinite(doubles.values[i]) : "Overflow on median differences"; } return doubleMedianOf(doubles.values, count); } @@ -240,7 +241,8 @@ static double doubleMedianOf(double[] values, int count) { // TODO quickselect Arrays.sort(values, 0, count); int middle = count / 2; - return count % 2 == 1 ? values[middle] : (values[middle - 1] / 2 + values[middle] / 2); + double median = count % 2 == 1 ? 
values[middle] : (values[middle - 1] / 2 + values[middle] / 2); + return NumericUtils.asFiniteNumber(median); } @MvEvaluator( From 80095022490b77870e2d3bcaa8d5b775a3b878c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Thu, 29 Aug 2024 18:03:11 +0200 Subject: [PATCH 16/23] Surrogate in aggregation --- .../main/resources/stats_percentile.csv-spec | 15 +++++++++++++++ .../aggregate/MedianAbsoluteDeviation.java | 17 ++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec index 2ac7a0cf6217a..a23fd0d132566 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec @@ -143,6 +143,21 @@ MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double // end::median-absolute-deviation-result[] ; +medianAbsoluteDeviationFold +ROW x = [0, 2, 5, 6] +| STATS + int_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::integer), + int_var = MEDIAN_ABSOLUTE_DEVIATION(x::integer), + long_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::long), + long_var = MEDIAN_ABSOLUTE_DEVIATION(x::long), + double_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::double), + double_var = MEDIAN_ABSOLUTE_DEVIATION(x::double) +; + +int_constant:double | int_var:double | long_constant:double | long_var:double | double_constant:double | double_var:double +2 | 2 | 2 | 2 | 2 | 2 +; + medianViaExpression from employees | stats p50 = percentile(salary_change, 25*2); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java index 46661e96b1d48..f3f396033e031 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java @@ -16,14 +16,17 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.SurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviation; import java.io.IOException; import java.util.List; -public class MedianAbsoluteDeviation extends NumericAggregate { +public class MedianAbsoluteDeviation extends NumericAggregate implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "MedianAbsoluteDeviation", @@ -97,4 +100,16 @@ protected AggregatorFunctionSupplier intSupplier(List inputChannels) { protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { return new MedianAbsoluteDeviationDoubleAggregatorFunctionSupplier(inputChannels); } + + @Override + public Expression surrogate() { + var s = source(); + var field = field(); + + if (field.foldable()) { + return new MvMedianAbsoluteDeviation(s, new ToDouble(s, field)); + } + + return null; + } } From 0ca55ecd898ecbc4347709394d7464b9e7de6147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Sep 2024 13:38:15 +0200 Subject: [PATCH 17/23] Extra cases for FROM test --- .../resources/mv_median_absolute_deviation.csv-spec | 12 +++++++----- 1 file changed, 7 
insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec index a8dbd1d384b80..f648fc5630469 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec @@ -17,18 +17,20 @@ fromIndex required_capability: fn_mv_median_absolute_deviation FROM employees -| WHERE emp_no <= 10002 +| WHERE emp_no IN (10001, 10002, 10007, 10009) | SORT emp_no ASC | EVAL int = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.int), long = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.long), double = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change) -| KEEP emp_no, int, long, double +| KEEP emp_no, salary_change, int, long, double ; -emp_no:integer | int:integer | long:long | double:double -10001 | 0 | 0 | 0 -10002 | 9 | 9 | 9.2 +emp_no:integer | salary_change:double | int:integer | long:long | double:double +10001 | 1.19 | 0 | 0 | 0 +10002 | [-7.23, 11.17] | 9 | 9 | 9.2 +10007 | [-7.06,0.57,1.99] | 1 | 1 | 1.42 +10009 | null | null | null | null ; allTypes From 47965c129bd4109e856ffeef352068aaf93bc4a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Sep 2024 14:15:42 +0200 Subject: [PATCH 18/23] Improved function documentation --- .../scalar/multivalue/MvMedianAbsoluteDeviation.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 685c3d01c4d24..e2a1d9b2dec80 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -46,10 +46,14 @@ public class MvMedianAbsoluteDeviation extends AbstractMultivalueFunction { @FunctionInfo( returnType = { "double", "integer", "long", "unsigned_long" }, - description = "Converts a multivalued field into a single valued field containing the median absolute deviation.", + description = "Converts a multivalued field into a single valued field containing the median absolute deviation." + + "\n\n" + + "It is calculated as the median of each data point's deviation from the median of " + + "the entire sample. That is, for a random variable `X`, the median absolute " + + "deviation is `median(|median(X) - X|)`.", note = "If the field has an even number of values, " + "the medians will be calculated as the average of the middle two values. 
" - + "If the column is not floating point, the averages round towards 0.", + + "If the value is not a floating point number, the averages are rounded towards 0.", examples = @Example(file = "mv_median_absolute_deviation", tag = "example") ) public MvMedianAbsoluteDeviation( From a07c58fe05663308f1afcf8d269914e8a6883b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 3 Sep 2024 14:35:47 +0200 Subject: [PATCH 19/23] Moved MAD agg csv tests to its own file and updated meta tests --- .../mv_median_absolute_deviation.asciidoc | 4 +- .../median_absolute_deviation.asciidoc | 8 ++-- .../mv_median_absolute_deviation.json | 4 +- .../docs/mv_median_absolute_deviation.md | 4 +- .../median_absolute_deviation.csv-spec | 40 ++++++++++++++++++ .../src/main/resources/meta.csv-spec | 2 +- .../main/resources/stats_percentile.csv-spec | 41 ------------------- .../aggregate/MedianAbsoluteDeviation.java | 4 +- 8 files changed, 54 insertions(+), 53 deletions(-) create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec diff --git a/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc index a681670c545a9..765c4d322c3dc 100644 --- a/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc +++ b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc @@ -2,6 +2,6 @@ *Description* -Converts a multivalued field into a single valued field containing the median absolute deviation. +Converts a multivalued field into a single valued field containing the median absolute deviation. It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`. 
-NOTE: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0. +NOTE: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0. diff --git a/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc b/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc index 20891126c20fb..9084c008e890a 100644 --- a/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc +++ b/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc @@ -4,19 +4,19 @@ [source.merge.styled,esql] ---- -include::{esql-specs}/stats_percentile.csv-spec[tag=median-absolute-deviation] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=median-absolute-deviation] ---- [%header.monospaced.styled,format=dsv,separator=|] |=== -include::{esql-specs}/stats_percentile.csv-spec[tag=median-absolute-deviation-result] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=median-absolute-deviation-result] |=== The expression can use inline functions. 
For example, to calculate the the median absolute deviation of the maximum values of a multivalued column, first use `MV_MAX` to get the maximum value per row, and use the result with the `MEDIAN_ABSOLUTE_DEVIATION` function [source.merge.styled,esql] ---- -include::{esql-specs}/stats_percentile.csv-spec[tag=docsStatsMADNestedExpression] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=docsStatsMADNestedExpression] ---- [%header.monospaced.styled,format=dsv,separator=|] |=== -include::{esql-specs}/stats_percentile.csv-spec[tag=docsStatsMADNestedExpression-result] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=docsStatsMADNestedExpression-result] |=== diff --git a/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json index a361722d2c44a..d6f1174a4e259 100644 --- a/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json +++ b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json @@ -2,8 +2,8 @@ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", "type" : "eval", "name" : "mv_median_absolute_deviation", - "description" : "Converts a multivalued field into a single valued field containing the median absolute deviation.", - "note" : "If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0.", + "description" : "Converts a multivalued field into a single valued field containing the median absolute deviation.\n\nIt is calculated as the median of each data point's deviation from the median of the entire sample. 
That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`.", + "note" : "If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0.", "signatures" : [ { "params" : [ diff --git a/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md index 5622cf7af6e50..191ce3ce60ae1 100644 --- a/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md +++ b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md @@ -5,8 +5,10 @@ This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../READ ### MV_MEDIAN_ABSOLUTE_DEVIATION Converts a multivalued field into a single valued field containing the median absolute deviation. +It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`. + ``` ROW values = [0, 2, 5, 6] | EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values) ``` -Note: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the column is not floating point, the averages round towards 0. +Note: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0. 
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec new file mode 100644 index 0000000000000..300a2ef86d0f9 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec @@ -0,0 +1,40 @@ +medianAbsoluteDeviation +// tag::median-absolute-deviation[] +FROM employees +| STATS MEDIAN(salary), MEDIAN_ABSOLUTE_DEVIATION(salary) +// end::median-absolute-deviation[] +; + +// tag::median-absolute-deviation-result[] +MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double +47003 | 10096.5 +// end::median-absolute-deviation-result[] +; + +medianAbsoluteDeviationFold +ROW x = [0, 2, 5, 6] +| STATS + int_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::integer), + int_var = MEDIAN_ABSOLUTE_DEVIATION(x::integer), + long_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::long), + long_var = MEDIAN_ABSOLUTE_DEVIATION(x::long), + double_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::double), + double_var = MEDIAN_ABSOLUTE_DEVIATION(x::double) +; + +int_constant:double | int_var:double | long_constant:double | long_var:double | double_constant:double | double_var:double +2 | 2 | 2 | 2 | 2 | 2 +; + +docsStatsMADNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] +// tag::docsStatsMADNestedExpression[] +FROM employees +| STATS m_a_d_max_salary_change = MEDIAN_ABSOLUTE_DEVIATION(MV_MAX(salary_change)) +// end::docsStatsMADNestedExpression[] +; + +// tag::docsStatsMADNestedExpression-result[] +m_a_d_max_salary_change:double +5.69 +// end::docsStatsMADNestedExpression-result[] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index cc45a42a8e64c..4501f6cf6e954 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ 
b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -303,7 +303,7 @@ mv_first |Converts a multivalued expression into a single valued column con mv_last |Converts a multivalue expression into a single valued column containing the last value. This is most useful when reading from a function that emits multivalued columns in a known order like <>. The order that <> are read from underlying storage is not guaranteed. It is *frequently* ascending, but don't rely on that. If you need the maximum value use <> instead of `MV_LAST`. `MV_MAX` has optimizations for sorted values so there isn't a performance benefit to `MV_LAST`. mv_max |Converts a multivalued expression into a single valued column containing the maximum value. mv_median |Converts a multivalued field into a single valued field containing the median value. -mv_median_abso|Converts a multivalued field into a single valued field containing the median absolute deviation. +mv_median_abso|"Converts a multivalued field into a single valued field containing the median absolute deviation. It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`." mv_min |Converts a multivalued expression into a single valued column containing the minimum value. mv_percentile |Converts a multivalued field into a single valued field containing the value at which a certain percentage of observed values occur. mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. 
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec index a23fd0d132566..7f4b8b24dbe63 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec @@ -130,34 +130,6 @@ m:double | p50:double | job_positions:keyword 3.9299999999999997 | 3.9299999999999997 | "Architect" ; -medianAbsoluteDeviation -// tag::median-absolute-deviation[] -FROM employees -| STATS MEDIAN(salary), MEDIAN_ABSOLUTE_DEVIATION(salary) -// end::median-absolute-deviation[] -; - -// tag::median-absolute-deviation-result[] -MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double -47003 | 10096.5 -// end::median-absolute-deviation-result[] -; - -medianAbsoluteDeviationFold -ROW x = [0, 2, 5, 6] -| STATS - int_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::integer), - int_var = MEDIAN_ABSOLUTE_DEVIATION(x::integer), - long_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::long), - long_var = MEDIAN_ABSOLUTE_DEVIATION(x::long), - double_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::double), - double_var = MEDIAN_ABSOLUTE_DEVIATION(x::double) -; - -int_constant:double | int_var:double | long_constant:double | long_var:double | double_constant:double | double_var:double -2 | 2 | 2 | 2 | 2 | 2 -; - medianViaExpression from employees | stats p50 = percentile(salary_change, 25*2); @@ -185,19 +157,6 @@ median_max_salary_change:double // end::docsStatsMedianNestedExpression-result[] ; -docsStatsMADNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] -// tag::docsStatsMADNestedExpression[] -FROM employees -| STATS m_a_d_max_salary_change = MEDIAN_ABSOLUTE_DEVIATION(MV_MAX(salary_change)) -// end::docsStatsMADNestedExpression[] -; - -// tag::docsStatsMADNestedExpression-result[] -m_a_d_max_salary_change:double -5.69 -// 
end::docsStatsMADNestedExpression-result[] -; - docsStatsPercentileNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] // tag::docsStatsPercentileNestedExpression[] FROM employees diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java index f3f396033e031..23a6b23a35cde 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java @@ -53,13 +53,13 @@ public class MedianAbsoluteDeviation extends NumericAggregate implements Surroga ====""", isAggregation = true, examples = { - @Example(file = "stats_percentile", tag = "median-absolute-deviation"), + @Example(file = "median_absolute_deviation", tag = "median-absolute-deviation"), @Example( description = "The expression can use inline functions. 
For example, to calculate the the " + "median absolute deviation of the maximum values of a multivalued column, first " + "use `MV_MAX` to get the maximum value per row, and use the result with the " + "`MEDIAN_ABSOLUTE_DEVIATION` function", - file = "stats_percentile", + file = "median_absolute_deviation", tag = "docsStatsMADNestedExpression" ), } ) From 09b1f31a5a1c30a2883fef2460e77b49a0406b1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Sep 2024 12:44:05 +0200 Subject: [PATCH 20/23] Added missing ascending case for doubles, fixed doubles infinites bug, and exceptions not resetting count --- ...edianAbsoluteDeviationDoubleEvaluator.java | 46 ++++ .../multivalue/MvMedianAbsoluteDeviation.java | 225 ++++++++++-------- 2 files changed, 177 insertions(+), 94 deletions(-) diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java index f576e02b142b1..7cefde819dedc 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java @@ -32,6 +32,9 @@ public String name() { */ @Override public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { @@ -66,6 +69,9 @@ public Block evalNullable(Block fieldVal) { */ @Override public Block evalNotNullable(Block fieldVal) { + if 
(fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { @@ -137,6 +143,46 @@ public Block evalSingleValuedNotNullable(Block fieldVal) { } } + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + double result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. 
+ */ + private Block evalAscendingNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + double result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { private final EvalOperator.ExpressionEvaluator.Factory field; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index e2a1d9b2dec80..70a377a378ea3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -108,6 +108,12 @@ static class Longs { public int count; } + /** + * Evaluator for integers. + *

+ * To avoid integer overflows, we're using the {@link Longs} class to store the values. + *

+ */ @MvEvaluator(extraName = "Int", finish = "finishInts", ascending = "ascending", single = "single") static void process(Longs longs, int v) { if (longs.values.length < longs.count + 1) { @@ -117,31 +123,38 @@ static void process(Longs longs, int v) { } static int finishInts(Longs longs) { - long median = longMedianOf(longs.values, longs.count); - for (int i = 0; i < longs.count; i++) { - long value = longs.values[i]; - // We know they were ints, so we can calculate differences within a long - longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + try { + long median = longMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + // We know they were ints, so we can calculate differences within a long + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + } + return Math.toIntExact(longMedianOf(longs)); + } finally { + longs.count = 0; } - int mad = Math.toIntExact(longMedianOf(longs.values, longs.count)); - longs.count = 0; - return mad; } /** - * If the values are ascending pick the middle value or average the two middle values together. + * If the values are ascending, we avoid the initial sorting. */ static int ascending(Longs longs, IntBlock values, int firstValue, int count) { - if (longs.values.length < count) { - longs.values = ArrayUtil.grow(longs.values, count); - } - int middle = firstValue + count / 2; - long median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); - for (int i = 0; i < count; i++) { - long value = values.getInt(firstValue + i); - longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median = count % 2 == 1 ? 
values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); + for (int i = 0; i < count; i++) { + long value = values.getInt(firstValue + i); + longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + } + return Math.toIntExact(longMedianOf(longs)); + } finally { + longs.count = 0; } - return Math.toIntExact(longMedianOf(longs.values, count)); } static int single(int value) { @@ -157,43 +170,50 @@ static void process(Longs longs, long v) { } static long finish(Longs longs) { - long median = longMedianOf(longs.values, longs.count); - for (int i = 0; i < longs.count; i++) { - long value = longs.values[i]; - // From here, this array contains unsigned longs - longs.values[i] = unsignedDifference(value, median); + try { + long median = longMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); + } + return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs)); + } finally { + longs.count = 0; } - long mad = NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs.values, longs.count)); - longs.count = 0; - return mad; } /** - * If the values are ascending pick the middle value or average the two middle values together. + * If the values are ascending, we avoid the initial sorting. */ static long ascending(Longs longs, LongBlock values, int firstValue, int count) { - if (longs.values.length < count) { - longs.values = ArrayUtil.grow(longs.values, count); - } - int middle = firstValue + count / 2; - long median = count % 2 == 1 ? 
values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); - for (int i = 0; i < count; i++) { - long value = values.getLong(firstValue + i); - // From here, this array contains unsigned longs - longs.values[i] = unsignedDifference(value, median); + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median = count % 2 == 1 ? values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); + } + return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs)); + } finally { + longs.count = 0; } - return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs.values, count)); } static long single(long value) { return 0L; } - static long longMedianOf(long[] values, int count) { + static long longMedianOf(Longs longs) { // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - return count % 2 == 1 ? values[middle] : avgWithoutOverflow(values[middle - 1], values[middle]); + Arrays.sort(longs.values, 0, longs.count); + int middle = longs.count / 2; + return longs.count % 2 == 1 ? 
longs.values[middle] : avgWithoutOverflow(longs.values[middle - 1], longs.values[middle]); } static class Doubles { @@ -201,7 +221,7 @@ static class Doubles { public int count; } - @MvEvaluator(extraName = "Double", finish = "finish", single = "single") + @MvEvaluator(extraName = "Double", finish = "finish", ascending = "ascending", single = "single") static void process(Doubles doubles, double v) { if (doubles.values.length < doubles.count + 1) { doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); @@ -210,42 +230,52 @@ static void process(Doubles doubles, double v) { } static double finish(Doubles doubles) { - double median = doubleMedianOf(doubles.values, doubles.count); - for (int i = 0; i < doubles.count; i++) { - double value = doubles.values[i]; - // Double differences between median and the values may potentially result in +/-Infinity. - // As we use that value just to sort, the MAD should remain finite. - doubles.values[i] = value > median ? value - median : median - value; + try { + double median = doubleMedianOf(doubles); + for (int i = 0; i < doubles.count; i++) { + double value = doubles.values[i]; + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. + doubles.values[i] = value > median ? value - median : median - value; + } + return doubleMedianOf(doubles); + } finally { + doubles.count = 0; } - double mad = doubleMedianOf(doubles.values, doubles.count); - doubles.count = 0; - return mad; } + /** + * If the values are ascending, we avoid the initial sorting. + */ static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { - if (doubles.values.length < count) { - doubles.values = ArrayUtil.grow(doubles.values, count); - } - int middle = firstValue + count / 2; - double median = count % 2 == 1 ? 
values.getDouble(middle) : (values.getDouble(middle - 1) + values.getDouble(middle)) / 2; - for (int i = 0; i < count; i++) { - double value = values.getDouble(firstValue + i); - // Double differences between median and the values may potentially result in +/-Infinity. - // As we use that value just to sort, the MAD should remain finite. - doubles.values[i] = value > median ? value - median : median - value; + try { + if (doubles.values.length < count) { + doubles.values = ArrayUtil.grow(doubles.values, count); + } + doubles.count = count; + int middle = firstValue + count / 2; + double median = count % 2 == 1 ? values.getDouble(middle) : (values.getDouble(middle - 1) / 2 + values.getDouble(middle) / 2); + for (int i = 0; i < count; i++) { + double value = values.getDouble(firstValue + i); + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. + doubles.values[i] = value > median ? value - median : median - value; + } + return doubleMedianOf(doubles); + } finally { + doubles.count = 0; } - return doubleMedianOf(doubles.values, count); } static double single(double value) { return 0.; } - static double doubleMedianOf(double[] values, int count) { + static double doubleMedianOf(Doubles doubles) { // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - double median = count % 2 == 1 ? values[middle] : (values[middle - 1] / 2 + values[middle] / 2); + Arrays.sort(doubles.values, 0, doubles.count); + int middle = doubles.count / 2; + double median = doubles.count % 2 == 1 ? 
doubles.values[middle] : (doubles.values[middle - 1] / 2 + doubles.values[middle] / 2); return NumericUtils.asFiniteNumber(median); } @@ -260,49 +290,56 @@ static void processUnsignedLong(Longs longs, long v) { } static long finishUnsignedLong(Longs longs) { - long median = unsignedLongMedianOf(longs.values, longs.count); - for (int i = 0; i < longs.count; i++) { - long value = longs.values[i]; - longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + try { + long median = unsignedLongMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + } + return unsignedLongMedianOf(longs); + } finally { + longs.count = 0; } - long mad = unsignedLongMedianOf(longs.values, longs.count); - longs.count = 0; - return mad; } /** - * If the values are ascending pick the middle value or average the two middle values together. + * If the values are ascending, we avoid the initial sorting. */ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) { - if (longs.values.length < count) { - longs.values = ArrayUtil.grow(longs.values, count); - } - int middle = firstValue + count / 2; - long median; - if (count % 2 == 1) { - median = values.getLong(middle); - } else { - median = unsignedLongAvgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); - } - for (int i = 0; i < count; i++) { - long value = values.getLong(firstValue + i); - longs.values[i] = value > median ? 
unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median; + if (count % 2 == 1) { + median = values.getLong(middle); + } else { + median = unsignedLongAvgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); + } + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + } + return unsignedLongMedianOf(longs); + } finally { + longs.count = 0; } - return unsignedLongMedianOf(longs.values, count); } static long singleUnsignedLong(long value) { return NumericUtils.ZERO_AS_UNSIGNED_LONG; } - static long unsignedLongMedianOf(long[] values, int count) { + static long unsignedLongMedianOf(Longs longs) { // TODO quickselect - Arrays.sort(values, 0, count); - int middle = count / 2; - if (count % 2 == 1) { - return values[middle]; + Arrays.sort(longs.values, 0, longs.count); + int middle = longs.count / 2; + if (longs.count % 2 == 1) { + return longs.values[middle]; } - return unsignedLongAvgWithoutOverflow(values[middle - 1], values[middle]); + return unsignedLongAvgWithoutOverflow(longs.values[middle - 1], longs.values[middle]); } // Utility methods From f11a7073f967f7e0962f4f9a4ec9e4b2970bb01f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Sep 2024 12:55:37 +0200 Subject: [PATCH 21/23] Simplified ints calculation by removing long overflow safeguards --- .../function/scalar/multivalue/MvMedianAbsoluteDeviation.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index 70a377a378ea3..dc91fadd49532 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -128,7 +128,7 @@ static int finishInts(Longs longs) { for (int i = 0; i < longs.count; i++) { long value = longs.values[i]; // We know they were ints, so we can calculate differences within a long - longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + longs.values[i] = value > median ? value - median : median - value; } return Math.toIntExact(longMedianOf(longs)); } finally { @@ -149,7 +149,7 @@ static int ascending(Longs longs, IntBlock values, int firstValue, int count) { long median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); for (int i = 0; i < count; i++) { long value = values.getInt(firstValue + i); - longs.values[i] = value > median ? Math.subtractExact(value, median) : Math.subtractExact(median, value); + longs.values[i] = value > median ? 
value - median : median - value; } return Math.toIntExact(longMedianOf(longs)); } finally { From a0a61cad5cec29aa9f93053b2dd9cf687983583d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Sep 2024 15:11:03 +0200 Subject: [PATCH 22/23] Add required capability to agg test --- .../src/main/resources/median_absolute_deviation.csv-spec | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec index 300a2ef86d0f9..9427ef0a30973 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec @@ -12,6 +12,7 @@ MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double ; medianAbsoluteDeviationFold +required_capability: fn_mv_median_absolute_deviation ROW x = [0, 2, 5, 6] | STATS int_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::integer), From 42377cee76a7825dcff1bbf58a05cba5bc753d2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Fri, 6 Sep 2024 16:36:52 +0200 Subject: [PATCH 23/23] Improved docs on `ascending` functions --- .../multivalue/MvMedianAbsoluteDeviation.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java index dc91fadd49532..9bdfd1a2ccafc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -137,7 +137,9 @@ static int finishInts(Longs longs) { } /** - * If the values are ascending, we avoid the initial sorting. + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. */ static int ascending(Longs longs, IntBlock values, int firstValue, int count) { try { @@ -184,7 +186,9 @@ static long finish(Longs longs) { } /** - * If the values are ascending, we avoid the initial sorting. + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. */ static long ascending(Longs longs, LongBlock values, int firstValue, int count) { try { @@ -245,7 +249,9 @@ static double finish(Doubles doubles) { } /** - * If the values are ascending, we avoid the initial sorting. + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. */ static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { try { @@ -303,7 +309,9 @@ static long finishUnsignedLong(Longs longs) { } /** - * If the values are ascending, we avoid the initial sorting. + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. 
+ * - We take the values directly from the block instead of from the helper object. */ static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) { try {