diff --git a/docs/changelog/112055.yaml b/docs/changelog/112055.yaml new file mode 100644 index 0000000000000..cdf15b3b37468 --- /dev/null +++ b/docs/changelog/112055.yaml @@ -0,0 +1,6 @@ +pr: 112055 +summary: "ESQL: `mv_median_absolute_deviation` function" +area: ES|QL +type: feature +issues: + - 111590 diff --git a/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..765c4d322c3dc --- /dev/null +++ b/docs/reference/esql/functions/description/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,7 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +Converts a multivalued field into a single valued field containing the median absolute deviation. It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`. + +NOTE: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0. diff --git a/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc b/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc index 20891126c20fb..9084c008e890a 100644 --- a/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc +++ b/docs/reference/esql/functions/examples/median_absolute_deviation.asciidoc @@ -4,19 +4,19 @@ [source.merge.styled,esql] ---- -include::{esql-specs}/stats_percentile.csv-spec[tag=median-absolute-deviation] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=median-absolute-deviation] ---- [%header.monospaced.styled,format=dsv,separator=|] |=== -include::{esql-specs}/stats_percentile.csv-spec[tag=median-absolute-deviation-result] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=median-absolute-deviation-result] |=== The expression can use inline functions. For example, to calculate the the median absolute deviation of the maximum values of a multivalued column, first use `MV_MAX` to get the maximum value per row, and use the result with the `MEDIAN_ABSOLUTE_DEVIATION` function [source.merge.styled,esql] ---- -include::{esql-specs}/stats_percentile.csv-spec[tag=docsStatsMADNestedExpression] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=docsStatsMADNestedExpression] ---- [%header.monospaced.styled,format=dsv,separator=|] |=== -include::{esql-specs}/stats_percentile.csv-spec[tag=docsStatsMADNestedExpression-result] +include::{esql-specs}/median_absolute_deviation.csv-spec[tag=docsStatsMADNestedExpression-result] |=== diff --git a/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..b36bc18a80174 --- /dev/null +++ b/docs/reference/esql/functions/examples/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Example* + +[source.merge.styled,esql] +---- +include::{esql-specs}/mv_median_absolute_deviation.csv-spec[tag=example] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/mv_median_absolute_deviation.csv-spec[tag=example-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json new file mode 100644 index 0000000000000..d6f1174a4e259 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/mv_median_absolute_deviation.json @@ -0,0 +1,60 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "eval", + "name" : "mv_median_absolute_deviation", + "description" : "Converts a multivalued field into a single valued field containing the median absolute deviation.\n\nIt is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`.", + "note" : "If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "long" + }, + { + "params" : [ + { + "name" : "number", + "type" : "unsigned_long", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "unsigned_long" + } + ], + "examples" : [ + "ROW values = [0, 2, 5, 6]\n| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values)" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md new file mode 100644 index 0000000000000..191ce3ce60ae1 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/mv_median_absolute_deviation.md @@ -0,0 +1,14 @@ + + +### MV_MEDIAN_ABSOLUTE_DEVIATION +Converts a multivalued field into a single valued field containing the median absolute deviation. + +It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`. + +``` +ROW values = [0, 2, 5, 6] +| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values) +``` +Note: If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0. diff --git a/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..b594d589e6108 --- /dev/null +++ b/docs/reference/esql/functions/layout/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-mv_median_absolute_deviation]] +=== `MV_MEDIAN_ABSOLUTE_DEVIATION` + +*Syntax* + +[.text-center] +image::esql/functions/signature/mv_median_absolute_deviation.svg[Embedded,opts=inline] + +include::../parameters/mv_median_absolute_deviation.asciidoc[] +include::../description/mv_median_absolute_deviation.asciidoc[] +include::../types/mv_median_absolute_deviation.asciidoc[] +include::../examples/mv_median_absolute_deviation.asciidoc[] diff --git a/docs/reference/esql/functions/mv-functions.asciidoc b/docs/reference/esql/functions/mv-functions.asciidoc index bd5f14cdd3557..4093e44c16911 100644 --- a/docs/reference/esql/functions/mv-functions.asciidoc +++ b/docs/reference/esql/functions/mv-functions.asciidoc @@ -17,6 +17,7 @@ * <> * <> * <> +* <> * <> * <> * <> @@ -34,6 +35,7 @@ include::layout/mv_first.asciidoc[] include::layout/mv_last.asciidoc[] include::layout/mv_max.asciidoc[] include::layout/mv_median.asciidoc[] +include::layout/mv_median_absolute_deviation.asciidoc[] include::layout/mv_min.asciidoc[] include::layout/mv_pseries_weighted_sum.asciidoc[] include::layout/mv_slice.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..47859c7e2b320 --- /dev/null +++ b/docs/reference/esql/functions/parameters/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: +Multivalue expression. diff --git a/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg b/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg new file mode 100644 index 0000000000000..7d8a131a91015 --- /dev/null +++ b/docs/reference/esql/functions/signature/mv_median_absolute_deviation.svg @@ -0,0 +1 @@ +MV_MEDIAN_ABSOLUTE_DEVIATION(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc b/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc new file mode 100644 index 0000000000000..d81bbf36ae3fe --- /dev/null +++ b/docs/reference/esql/functions/types/mv_median_absolute_deviation.asciidoc @@ -0,0 +1,12 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | integer +long | long +unsigned_long | unsigned_long +|=== diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java index 3bff45db5023c..1f960f79c7901 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java @@ -88,6 +88,18 @@ public static Number unsignedLongAsNumber(long l) { return l < 0 ? twosComplement(l) : LONG_MAX_PLUS_ONE_AS_BIGINTEGER.add(BigInteger.valueOf(l)); } + /** + * Converts an unsigned long value "encoded" into a (signed) long. + * In case of overflow, an ArithmeticException is thrown. + */ + public static long unsignedLongAsLongExact(long l) { + if (l < 0) { + return twosComplement(l); + } + + throw new ArithmeticException(UNSIGNED_LONG_OVERFLOW); + } + public static BigInteger unsignedLongAsBigInteger(long l) { return l < 0 ? BigInteger.valueOf(twosComplement(l)) : LONG_MAX_PLUS_ONE_AS_BIGINTEGER.add(BigInteger.valueOf(l)); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java index 993b8363fb35f..d6e062facdbfd 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/MvEvaluatorImplementer.java @@ -107,7 +107,7 @@ public MvEvaluatorImplementer( this.resultType = TypeName.get(processFunction.getReturnType()); } this.singleValueFunction = SingleValueFunction.from(declarationType, singleValueMethodName, resultType, fieldType); - this.ascendingFunction = AscendingFunction.from(this, declarationType, ascendingMethodName); + this.ascendingFunction = AscendingFunction.from(this, declarationType, workType, ascendingMethodName); this.warnExceptions = warnExceptions; this.implementation = ClassName.get( elements.getPackageOf(declarationType).toString(), @@ -511,7 +511,7 @@ private void call(MethodSpec.Builder builder) { * Function handling blocks of ascending values. */ private class AscendingFunction { - static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarationType, String name) { + static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarationType, TypeName workType, String name) { if (name.equals("")) { return null; } @@ -523,8 +523,9 @@ static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarati m -> m.getParameters().size() == 1 && m.getParameters().get(0).asType().getKind() == TypeKind.INT ); if (fn != null) { - return impl.new AscendingFunction(fn, false); + return impl.new AscendingFunction(fn, false, false); } + // Block mode without work parameter fn = findMethod( declarationType, new String[] { name }, @@ -532,17 +533,31 @@ static AscendingFunction from(MvEvaluatorImplementer impl, TypeElement declarati && m.getParameters().get(1).asType().getKind() == TypeKind.INT && m.getParameters().get(2).asType().getKind() == TypeKind.INT ); - if (fn == null) { - throw new IllegalArgumentException("Couldn't find " + declarationType + "#" + name + "(block, int, int)"); + if (fn != null) { + return impl.new AscendingFunction(fn, true, false); } - return impl.new AscendingFunction(fn, true); + // Block mode with work parameter + fn = findMethod( + declarationType, + new String[] { name }, + m -> m.getParameters().size() == 4 + && TypeName.get(m.getParameters().get(0).asType()).equals(workType) + && m.getParameters().get(2).asType().getKind() == TypeKind.INT + && m.getParameters().get(3).asType().getKind() == TypeKind.INT + ); + if (fn != null) { + return impl.new AscendingFunction(fn, true, true); + } + throw new IllegalArgumentException("Couldn't find " + declarationType + "#" + name + "(block, int, int)"); } private final List invocationArgs = new ArrayList<>(); private final boolean blockMode; + private final boolean withWorkParameter; - private AscendingFunction(ExecutableElement fn, boolean blockMode) { + private AscendingFunction(ExecutableElement fn, boolean blockMode, boolean withWorkParameter) { this.blockMode = blockMode; + this.withWorkParameter = withWorkParameter; if (blockMode) { invocationArgs.add(resultType); } @@ -552,7 +567,11 @@ private AscendingFunction(ExecutableElement fn, boolean blockMode) { private void call(MethodSpec.Builder builder) { if (blockMode) { - builder.addStatement("$T result = $T.$L(v, first, valueCount)", invocationArgs.toArray()); + if (withWorkParameter) { + builder.addStatement("$T result = $T.$L(work, v, first, valueCount)", invocationArgs.toArray()); + } else { + builder.addStatement("$T result = $T.$L(v, first, valueCount)", invocationArgs.toArray()); + } } else { builder.addStatement("int idx = $T.$L(valueCount)", invocationArgs.toArray()); fetch(builder, "result", resultType, "first + idx", workType.equals(fieldType) ? "firstScratch" : "valueScratch"); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec new file mode 100644 index 0000000000000..9427ef0a30973 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/median_absolute_deviation.csv-spec @@ -0,0 +1,41 @@ +medianAbsoluteDeviation +// tag::median-absolute-deviation[] +FROM employees +| STATS MEDIAN(salary), MEDIAN_ABSOLUTE_DEVIATION(salary) +// end::median-absolute-deviation[] +; + +// tag::median-absolute-deviation-result[] +MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double +47003 | 10096.5 +// end::median-absolute-deviation-result[] +; + +medianAbsoluteDeviationFold +required_capability: fn_mv_median_absolute_deviation +ROW x = [0, 2, 5, 6] +| STATS + int_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::integer), + int_var = MEDIAN_ABSOLUTE_DEVIATION(x::integer), + long_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::long), + long_var = MEDIAN_ABSOLUTE_DEVIATION(x::long), + double_constant = MEDIAN_ABSOLUTE_DEVIATION([0, 2, 5, 6]::double), + double_var = MEDIAN_ABSOLUTE_DEVIATION(x::double) +; + +int_constant:double | int_var:double | long_constant:double | long_var:double | double_constant:double | double_var:double +2 | 2 | 2 | 2 | 2 | 2 +; + +docsStatsMADNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] +// tag::docsStatsMADNestedExpression[] +FROM employees +| STATS m_a_d_max_salary_change = MEDIAN_ABSOLUTE_DEVIATION(MV_MAX(salary_change)) +// end::docsStatsMADNestedExpression[] +; + +// tag::docsStatsMADNestedExpression-result[] +m_a_d_max_salary_change:double +5.69 +// end::docsStatsMADNestedExpression-result[] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index 325b984c36d34..f679cd333f76a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -53,6 +53,7 @@ double e() "boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version mv_last(field:boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version)" "boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version mv_max(field:boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version)" "double|integer|long|unsigned_long mv_median(number:double|integer|long|unsigned_long)" +"double|integer|long|unsigned_long mv_median_absolute_deviation(number:double|integer|long|unsigned_long)" "boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version mv_min(field:boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version)" "double|integer|long mv_percentile(number:double|integer|long, percentile:double|integer|long)" "double mv_pseries_weighted_sum(number:double, p:double)" @@ -177,6 +178,7 @@ mv_first |field |"boolean|cartesian_point|car mv_last |field |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_max |field |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_median |number |"double|integer|long|unsigned_long" |Multivalue expression. +mv_median_abso|number |"double|integer|long|unsigned_long" |Multivalue expression. mv_min |field |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_percentile |[number, percentile] |["double|integer|long", "double|integer|long"] |[Multivalue expression., The percentile to calculate. Must be a number between 0 and 100. Numbers out of range will return a null instead.] mv_pseries_wei|[number, p] |[double, double] |[Multivalue expression., It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum.] @@ -301,6 +303,7 @@ mv_first |Converts a multivalued expression into a single valued column con mv_last |Converts a multivalue expression into a single valued column containing the last value. This is most useful when reading from a function that emits multivalued columns in a known order like <>. The order that <> are read from underlying storage is not guaranteed. It is *frequently* ascending, but don't rely on that. If you need the maximum value use <> instead of `MV_LAST`. `MV_MAX` has optimizations for sorted values so there isn't a performance benefit to `MV_LAST`. mv_max |Converts a multivalued expression into a single valued column containing the maximum value. mv_median |Converts a multivalued field into a single valued field containing the median value. +mv_median_abso|"Converts a multivalued field into a single valued field containing the median absolute deviation. It is calculated as the median of each data point's deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`." mv_min |Converts a multivalued expression into a single valued column containing the minimum value. mv_percentile |Converts a multivalued field into a single valued field containing the value at which a certain percentage of observed values occur. mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. @@ -427,6 +430,7 @@ mv_first |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|g mv_last |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version"|false |false |false mv_max |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false mv_median |"double|integer|long|unsigned_long" |false |false |false +mv_median_abso|"double|integer|long|unsigned_long" |false |false |false mv_min |"boolean|date|date_nanos|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false mv_percentile |"double|integer|long" |[false, false] |false |false mv_pseries_wei|"double" |[false, false] |false |false @@ -508,5 +512,5 @@ countFunctions#[skip:-8.15.99] meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; a:long | b:long | c:long -115 | 115 | 115 +116 | 116 | 116 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec new file mode 100644 index 0000000000000..f648fc5630469 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_median_absolute_deviation.csv-spec @@ -0,0 +1,84 @@ +example +required_capability: fn_mv_median_absolute_deviation + +// tag::example[] +ROW values = [0, 2, 5, 6] +| EVAL median_absolute_deviation = MV_MEDIAN_ABSOLUTE_DEVIATION(values), median = MV_MEDIAN(values) +// end::example[] +; + +// tag::example-result[] +values:integer | median_absolute_deviation:integer | median:integer +[0, 2, 5, 6] | 2 | 3 +// end::example-result[] +; + +fromIndex +required_capability: fn_mv_median_absolute_deviation + +FROM employees +| WHERE emp_no IN (10001, 10002, 10007, 10009) +| SORT emp_no ASC +| EVAL + int = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.int), + long = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change.long), + double = MV_MEDIAN_ABSOLUTE_DEVIATION(salary_change) +| KEEP emp_no, salary_change, int, long, double +; + +emp_no:integer | salary_change:double | int:integer | long:long | double:double +10001 | 1.19 | 0 | 0 | 0 +10002 | [-7.23, 11.17] | 9 | 9 | 9.2 +10007 | [-7.06,0.57,1.99] | 1 | 1 | 1.42 +10009 | null | null | null | null +; + +allTypes +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6] +| EVAL + int = MV_MEDIAN_ABSOLUTE_DEVIATION(x::INTEGER), + long = MV_MEDIAN_ABSOLUTE_DEVIATION(x::LONG), + double = MV_MEDIAN_ABSOLUTE_DEVIATION(x::DOUBLE), + ul = MV_MEDIAN_ABSOLUTE_DEVIATION(x::UNSIGNED_LONG) +| KEEP int, long, double, ul +; + +int:integer | long:long | double:double | ul:unsigned_long +2 | 2 | 2 | 2 +; + +multipleExpressions +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6] +| EVAL + MV_MEDIAN_ABSOLUTE_DEVIATION(x), + a = MV_MEDIAN_ABSOLUTE_DEVIATION(x), + b = MV_MEDIAN_ABSOLUTE_DEVIATION(TO_DOUBLE([2, 5])), + c = MV_MEDIAN_ABSOLUTE_DEVIATION(CASE(true, x, [0, 1])) +; + +x:integer | MV_MEDIAN_ABSOLUTE_DEVIATION(x):integer | a:integer | b:double | c:integer +[0, 2, 5, 6] | 2 | 2 | 1.5 | 2 +; + +nullsAndFolds +required_capability: fn_mv_median_absolute_deviation + +ROW x = [0, 2, 5, 6], single = 300 +| EVAL evalNull = null / 2, evalValue = 31 + 1 +| LIMIT 1 +| EVAL + a = MV_MEDIAN_ABSOLUTE_DEVIATION(x), + b = MV_MEDIAN_ABSOLUTE_DEVIATION(single), + c = MV_MEDIAN_ABSOLUTE_DEVIATION(null), + d = MV_MEDIAN_ABSOLUTE_DEVIATION(evalNull), + e = MV_MEDIAN_ABSOLUTE_DEVIATION(evalValue) +| KEEP a, b, c, d, e +; + +a:integer | b:integer | c:null | d:integer | e:integer +2 | 0 | null | null | 0 +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec index 2ac7a0cf6217a..7f4b8b24dbe63 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec @@ -130,19 +130,6 @@ m:double | p50:double | job_positions:keyword 3.9299999999999997 | 3.9299999999999997 | "Architect" ; -medianAbsoluteDeviation -// tag::median-absolute-deviation[] -FROM employees -| STATS MEDIAN(salary), MEDIAN_ABSOLUTE_DEVIATION(salary) -// end::median-absolute-deviation[] -; - -// tag::median-absolute-deviation-result[] -MEDIAN(salary):double | MEDIAN_ABSOLUTE_DEVIATION(salary):double -47003 | 10096.5 -// end::median-absolute-deviation-result[] -; - medianViaExpression from employees | stats p50 = percentile(salary_change, 25*2); @@ -170,19 +157,6 @@ median_max_salary_change:double // end::docsStatsMedianNestedExpression-result[] ; -docsStatsMADNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] -// tag::docsStatsMADNestedExpression[] -FROM employees -| STATS m_a_d_max_salary_change = MEDIAN_ABSOLUTE_DEVIATION(MV_MAX(salary_change)) -// end::docsStatsMADNestedExpression[] -; - -// tag::docsStatsMADNestedExpression-result[] -m_a_d_max_salary_change:double -5.69 -// end::docsStatsMADNestedExpression-result[] -; - docsStatsPercentileNestedExpression#[skip:-8.12.99,reason:supported in 8.13+] // tag::docsStatsPercentileNestedExpression[] FROM employees diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java new file mode 100644 index 0000000000000..7cefde819dedc --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationDoubleEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationDoubleEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationDoubleEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + double value = v.getDouble(i); + MvMedianAbsoluteDeviation.process(work, value); + } + double result = MvMedianAbsoluteDeviation.finish(work); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + double value = v.getDouble(i); + MvMedianAbsoluteDeviation.process(work, value); + } + double result = MvMedianAbsoluteDeviation.finish(work); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + double value = v.getDouble(first); + double result = MvMedianAbsoluteDeviation.single(value); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + double result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Doubles work = new MvMedianAbsoluteDeviation.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + double result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationDoubleEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationDoubleEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java new file mode 100644 index 0000000000000..76013ca1115db --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationIntEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationIntEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationIntEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + int value = v.getInt(i); + MvMedianAbsoluteDeviation.process(work, value); + } + int result = MvMedianAbsoluteDeviation.finishInts(work); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + int value = v.getInt(i); + MvMedianAbsoluteDeviation.process(work, value); + } + int result = MvMedianAbsoluteDeviation.finishInts(work); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + int value = v.getInt(first); + int result = MvMedianAbsoluteDeviation.single(value); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntBlock.Builder builder = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + int result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendInt(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + IntBlock v = (IntBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (IntVector.FixedBuilder builder = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + int result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendInt(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationIntEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationIntEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java new file mode 100644 index 0000000000000..e7883d92708b7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationLongEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationLongEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationLongEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.process(work, value); + } + long result = MvMedianAbsoluteDeviation.finish(work); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.process(work, value); + } + long result = MvMedianAbsoluteDeviation.finish(work); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.single(value); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascending(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationLongEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationLongEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java new file mode 100644 index 0000000000000..ef8781e1dc048 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationUnsignedLongEvaluator.java @@ -0,0 +1,203 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvMedianAbsoluteDeviation}. + * This class is generated. Do not edit it. + */ +public final class MvMedianAbsoluteDeviationUnsignedLongEvaluator extends AbstractMultivalueFunction.AbstractEvaluator { + public MvMedianAbsoluteDeviationUnsignedLongEvaluator(EvalOperator.ExpressionEvaluator field, + DriverContext driverContext) { + super(driverContext, field); + } + + @Override + public String name() { + return "MvMedianAbsoluteDeviation"; + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.processUnsignedLong(work, value); + } + long result = MvMedianAbsoluteDeviation.finishUnsignedLong(work); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field. + */ + @Override + public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + if (valueCount == 1) { + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + continue; + } + int end = first + valueCount; + for (int i = first; i < end; i++) { + long value = v.getLong(i); + MvMedianAbsoluteDeviation.processUnsignedLong(work, value); + } + long result = MvMedianAbsoluteDeviation.finishUnsignedLong(work); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing only single valued fields. + */ + @Override + public Block evalSingleValuedNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + assert valueCount == 1; + int first = v.getFirstValueIndex(p); + long value = v.getLong(first); + long result = MvMedianAbsoluteDeviation.singleUnsignedLong(value); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascendingUnsignedLong(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + LongBlock v = (LongBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (LongVector.FixedBuilder builder = driverContext.blockFactory().newLongVectorFixedBuilder(positionCount)) { + MvMedianAbsoluteDeviation.Longs work = new MvMedianAbsoluteDeviation.Longs(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + long result = MvMedianAbsoluteDeviation.ascendingUnsignedLong(work, v, first, valueCount); + builder.appendLong(result); + } + return builder.build().asBlock(); + } + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final EvalOperator.ExpressionEvaluator.Factory field; + + public Factory(EvalOperator.ExpressionEvaluator.Factory field) { + this.field = field; + } + + @Override + public MvMedianAbsoluteDeviationUnsignedLongEvaluator get(DriverContext context) { + return new MvMedianAbsoluteDeviationUnsignedLongEvaluator(field.get(context), context); + } + + @Override + public String toString() { + return "MvMedianAbsoluteDeviation[field=" + field + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java index c3ea505a29e88..e3b539d8210aa 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianDoubleEvaluator.java @@ -32,6 +32,9 @@ public String name() { */ @Override public Block evalNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { @@ -60,6 +63,9 @@ public Block evalNullable(Block fieldVal) { */ @Override public Block evalNotNullable(Block fieldVal) { + if (fieldVal.mvSortedAscending()) { + return evalAscendingNotNullable(fieldVal); + } DoubleBlock v = (DoubleBlock) fieldVal; int positionCount = v.getPositionCount(); try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { @@ -79,6 +85,46 @@ public Block evalNotNullable(Block fieldVal) { } } + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + MvMedian.Doubles work = new MvMedian.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + if (valueCount == 0) { + builder.appendNull(); + continue; + } + int first = v.getFirstValueIndex(p); + double result = MvMedian.ascending(v, first, valueCount); + builder.appendDouble(result); + } + return builder.build(); + } + } + + /** + * Evaluate blocks containing at least one multivalued field and all multivalued fields are in ascending order. + */ + private Block evalAscendingNotNullable(Block fieldVal) { + DoubleBlock v = (DoubleBlock) fieldVal; + int positionCount = v.getPositionCount(); + try (DoubleVector.FixedBuilder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + MvMedian.Doubles work = new MvMedian.Doubles(); + for (int p = 0; p < positionCount; p++) { + int valueCount = v.getValueCount(p); + int first = v.getFirstValueIndex(p); + double result = MvMedian.ascending(v, first, valueCount); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { private final EvalOperator.ExpressionEvaluator.Factory field; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 120323ebeb7a6..842b3dc30464c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -37,6 +37,11 @@ public enum Cap { */ FN_MV_APPEND, + /** + * Support for {@code MV_MEDIAN_ABSOLUTE_DEVIATION} function. + */ + FN_MV_MEDIAN_ABSOLUTE_DEVIATION, + /** * Support for {@code MV_PERCENTILE} function. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 0d50623fe77eb..99794dd5875f8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -94,6 +94,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvLast; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviation; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPSeriesWeightedSum; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPercentile; @@ -363,6 +364,7 @@ private FunctionDefinition[][] functions() { def(MvLast.class, MvLast::new, "mv_last"), def(MvMax.class, MvMax::new, "mv_max"), def(MvMedian.class, MvMedian::new, "mv_median"), + def(MvMedianAbsoluteDeviation.class, MvMedianAbsoluteDeviation::new, "mv_median_absolute_deviation"), def(MvMin.class, MvMin::new, "mv_min"), def(MvPercentile.class, MvPercentile::new, "mv_percentile"), def(MvPSeriesWeightedSum.class, MvPSeriesWeightedSum::new, "mv_pseries_weighted_sum"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java index 46661e96b1d48..23a6b23a35cde 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java @@ -16,14 +16,17 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.SurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviation; import java.io.IOException; import java.util.List; -public class MedianAbsoluteDeviation extends NumericAggregate { +public class MedianAbsoluteDeviation extends NumericAggregate implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "MedianAbsoluteDeviation", @@ -50,13 +53,13 @@ public class MedianAbsoluteDeviation extends NumericAggregate { ====""", isAggregation = true, examples = { - @Example(file = "stats_percentile", tag = "median-absolute-deviation"), + @Example(file = "median_absolute_deviation", tag = "median-absolute-deviation"), @Example( description = "The expression can use inline functions. For example, to calculate the the " + "median absolute deviation of the maximum values of a multivalued column, first " + "use `MV_MAX` to get the maximum value per row, and use the result with the " + "`MEDIAN_ABSOLUTE_DEVIATION` function", - file = "stats_percentile", + file = "median_absolute_deviation", tag = "docsStatsMADNestedExpression" ), } ) @@ -97,4 +100,16 @@ protected AggregatorFunctionSupplier intSupplier(List inputChannels) { protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { return new MedianAbsoluteDeviationDoubleAggregatorFunctionSupplier(inputChannels); } + + @Override + public Expression surrogate() { + var s = source(); + var field = field(); + + if (field.foldable()) { + return new MvMedianAbsoluteDeviation(s, new ToDouble(s, field)); + } + + return null; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java index cb0f9fdd8d5db..998a1815cbada 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java @@ -43,6 +43,7 @@ public static List getNamedWriteables() { MvLast.ENTRY, MvMax.ENTRY, MvMedian.ENTRY, + MvMedianAbsoluteDeviation.ENTRY, MvMin.ENTRY, MvPercentile.ENTRY, MvPSeriesWeightedSum.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java index e9e6899117805..42510a5685de3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedian.java @@ -36,7 +36,7 @@ import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.unsignedLongToBigInteger; /** - * Reduce a multivalued field to a single valued field containing the average value. + * Reduce a multivalued field to a single valued field containing the median of the values. */ public class MvMedian extends AbstractMultivalueFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "MvMedian", MvMedian::new); @@ -106,7 +106,7 @@ static class Doubles { public int count; } - @MvEvaluator(extraName = "Double", finish = "finish") + @MvEvaluator(extraName = "Double", finish = "finish", ascending = "ascending") static void process(Doubles doubles, double v) { if (doubles.values.length < doubles.count + 1) { doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java new file mode 100644 index 0000000000000..9bdfd1a2ccafc --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviation.java @@ -0,0 +1,402 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.apache.lucene.util.ArrayUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.ann.MvEvaluator; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.NumericUtils; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.isRepresentable; +import static org.elasticsearch.xpack.esql.core.util.NumericUtils.unsignedLongSubtractExact; + +/** + * Reduce a multivalued field to a single valued field containing the median absolute deviation of the values. + */ +public class MvMedianAbsoluteDeviation extends AbstractMultivalueFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MvMedianAbsoluteDeviation", + MvMedianAbsoluteDeviation::new + ); + + @FunctionInfo( + returnType = { "double", "integer", "long", "unsigned_long" }, + description = "Converts a multivalued field into a single valued field containing the median absolute deviation." + + "\n\n" + + "It is calculated as the median of each data point's deviation from the median of " + + "the entire sample. That is, for a random variable `X`, the median absolute " + + "deviation is `median(|median(X) - X|)`.", + note = "If the field has an even number of values, " + + "the medians will be calculated as the average of the middle two values. " + + "If the value is not a floating point number, the averages are rounded towards 0.", + examples = @Example(file = "mv_median_absolute_deviation", tag = "example") + ) + public MvMedianAbsoluteDeviation( + Source source, + @Param( + name = "number", + type = { "double", "integer", "long", "unsigned_long" }, + description = "Multivalue expression." + ) Expression field + ) { + super(source, field); + } + + private MvMedianAbsoluteDeviation(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveFieldType() { + return isType(field(), t -> t.isNumeric() && isRepresentable(t), sourceText(), null, "numeric"); + } + + @Override + protected ExpressionEvaluator.Factory evaluator(ExpressionEvaluator.Factory fieldEval) { + return switch (PlannerUtils.toElementType(field().dataType())) { + case DOUBLE -> new MvMedianAbsoluteDeviationDoubleEvaluator.Factory(fieldEval); + case INT -> new MvMedianAbsoluteDeviationIntEvaluator.Factory(fieldEval); + case LONG -> field().dataType() == DataType.UNSIGNED_LONG + ? new MvMedianAbsoluteDeviationUnsignedLongEvaluator.Factory(fieldEval) + : new MvMedianAbsoluteDeviationLongEvaluator.Factory(fieldEval); + default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType()); + }; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new MvMedianAbsoluteDeviation(source(), newChildren.get(0)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, MvMedianAbsoluteDeviation::new, field()); + } + + static class Longs { + public long[] values = new long[2]; + public int count; + } + + /** + * Evaluator for integers. + *

+ * To avoid integer overflows, we're using the {@link Longs} class to store the values. + *

+ */ + @MvEvaluator(extraName = "Int", finish = "finishInts", ascending = "ascending", single = "single") + static void process(Longs longs, int v) { + if (longs.values.length < longs.count + 1) { + longs.values = ArrayUtil.grow(longs.values, longs.count + 1); + } + longs.values[longs.count++] = v; + } + + static int finishInts(Longs longs) { + try { + long median = longMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + // We know they were ints, so we can calculate differences within a long + longs.values[i] = value > median ? value - median : median - value; + } + return Math.toIntExact(longMedianOf(longs)); + } finally { + longs.count = 0; + } + } + + /** + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. + */ + static int ascending(Longs longs, IntBlock values, int firstValue, int count) { + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median = count % 2 == 1 ? values.getInt(middle) : avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle)); + for (int i = 0; i < count; i++) { + long value = values.getInt(firstValue + i); + longs.values[i] = value > median ? value - median : median - value; + } + return Math.toIntExact(longMedianOf(longs)); + } finally { + longs.count = 0; + } + } + + static int single(int value) { + return 0; + } + + @MvEvaluator(extraName = "Long", finish = "finish", ascending = "ascending", single = "single") + static void process(Longs longs, long v) { + if (longs.values.length < longs.count + 1) { + longs.values = ArrayUtil.grow(longs.values, longs.count + 1); + } + longs.values[longs.count++] = v; + } + + static long finish(Longs longs) { + try { + long median = longMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); + } + return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs)); + } finally { + longs.count = 0; + } + } + + /** + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. + */ + static long ascending(Longs longs, LongBlock values, int firstValue, int count) { + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median = count % 2 == 1 ? values.getLong(middle) : avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + // From here, this array contains unsigned longs + longs.values[i] = unsignedDifference(value, median); + } + return NumericUtils.unsignedLongAsLongExact(unsignedLongMedianOf(longs)); + } finally { + longs.count = 0; + } + } + + static long single(long value) { + return 0L; + } + + static long longMedianOf(Longs longs) { + // TODO quickselect + Arrays.sort(longs.values, 0, longs.count); + int middle = longs.count / 2; + return longs.count % 2 == 1 ? longs.values[middle] : avgWithoutOverflow(longs.values[middle - 1], longs.values[middle]); + } + + static class Doubles { + public double[] values = new double[2]; + public int count; + } + + @MvEvaluator(extraName = "Double", finish = "finish", ascending = "ascending", single = "single") + static void process(Doubles doubles, double v) { + if (doubles.values.length < doubles.count + 1) { + doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1); + } + doubles.values[doubles.count++] = v; + } + + static double finish(Doubles doubles) { + try { + double median = doubleMedianOf(doubles); + for (int i = 0; i < doubles.count; i++) { + double value = doubles.values[i]; + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. + doubles.values[i] = value > median ? value - median : median - value; + } + return doubleMedianOf(doubles); + } finally { + doubles.count = 0; + } + } + + /** + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. + */ + static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) { + try { + if (doubles.values.length < count) { + doubles.values = ArrayUtil.grow(doubles.values, count); + } + doubles.count = count; + int middle = firstValue + count / 2; + double median = count % 2 == 1 ? values.getDouble(middle) : (values.getDouble(middle - 1) / 2 + values.getDouble(middle) / 2); + for (int i = 0; i < count; i++) { + double value = values.getDouble(firstValue + i); + // Double differences between median and the values may potentially result in +/-Infinity. + // As we use that value just to sort, the MAD should remain finite. + doubles.values[i] = value > median ? value - median : median - value; + } + return doubleMedianOf(doubles); + } finally { + doubles.count = 0; + } + } + + static double single(double value) { + return 0.; + } + + static double doubleMedianOf(Doubles doubles) { + // TODO quickselect + Arrays.sort(doubles.values, 0, doubles.count); + int middle = doubles.count / 2; + double median = doubles.count % 2 == 1 ? doubles.values[middle] : (doubles.values[middle - 1] / 2 + doubles.values[middle] / 2); + return NumericUtils.asFiniteNumber(median); + } + + @MvEvaluator( + extraName = "UnsignedLong", + finish = "finishUnsignedLong", + ascending = "ascendingUnsignedLong", + single = "singleUnsignedLong" + ) + static void processUnsignedLong(Longs longs, long v) { + process(longs, v); + } + + static long finishUnsignedLong(Longs longs) { + try { + long median = unsignedLongMedianOf(longs); + for (int i = 0; i < longs.count; i++) { + long value = longs.values[i]; + longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + } + return unsignedLongMedianOf(longs); + } finally { + longs.count = 0; + } + } + + /** + * Similar to the code in `finish`, for when the values are in ascending order. The major differences are: + * - As values are sorted, we don't need to sort them for the first median calculation. + * - We take the values directly from the block instead of from the helper object. + */ + static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) { + try { + if (longs.values.length < count) { + longs.values = ArrayUtil.grow(longs.values, count); + } + longs.count = count; + int middle = firstValue + count / 2; + long median; + if (count % 2 == 1) { + median = values.getLong(middle); + } else { + median = unsignedLongAvgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle)); + } + for (int i = 0; i < count; i++) { + long value = values.getLong(firstValue + i); + longs.values[i] = value > median ? unsignedLongSubtractExact(value, median) : unsignedLongSubtractExact(median, value); + } + return unsignedLongMedianOf(longs); + } finally { + longs.count = 0; + } + } + + static long singleUnsignedLong(long value) { + return NumericUtils.ZERO_AS_UNSIGNED_LONG; + } + + static long unsignedLongMedianOf(Longs longs) { + // TODO quickselect + Arrays.sort(longs.values, 0, longs.count); + int middle = longs.count / 2; + if (longs.count % 2 == 1) { + return longs.values[middle]; + } + return unsignedLongAvgWithoutOverflow(longs.values[middle - 1], longs.values[middle]); + } + + // Utility methods + + /** + * Average two {@code int}s together without overflow. + */ + static int avgWithoutOverflow(int a, int b) { + var value = (a & b) + ((a ^ b) >> 1); + // This method rounds negative values down instead of towards zero, like (a + b) / 2 would. + // Here we rectify up if the average is negative and the two values have different parities. + return value < 0 && ((a & 1) ^ (b & 1)) == 1 ? value + 1 : value; + } + + /** + * Average two {@code long}s without any overflow. + */ + static long avgWithoutOverflow(long a, long b) { + var value = (a & b) + ((a ^ b) >> 1); + // This method rounds negative values down instead of towards zero, like (a + b) / 2 would. + // Here we rectify up if the average is negative and the two values have different parities. + return value < 0 && ((a & 1) ^ (b & 1)) == 1 ? value + 1 : value; + } + + /** + * Average two {@code unsigned long}s without any overflow. + */ + static long unsignedLongAvgWithoutOverflow(long a, long b) { + return (a >> 1) + (b >> 1) + (a & b & 1); + } + + /** + * Returns the difference between two signed long values as an unsigned long. + */ + static long unsignedDifference(long a, long b) { + if (a >= b) { + if (a < 0 || b >= 0) { + return NumericUtils.asLongUnsigned(a - b); + } + + return NumericUtils.unsignedLongSubtractExact(a, b); + } + + // Same operations, but inverted + + if (b < 0 || a >= 0) { + return NumericUtils.asLongUnsigned(b - a); + } + + return NumericUtils.unsignedLongSubtractExact(b, a); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java new file mode 100644 index 0000000000000..6a63c38c924d9 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationSerializationTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; + +public class MvMedianAbsoluteDeviationSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected MvMedianAbsoluteDeviation createTestInstance() { + return new MvMedianAbsoluteDeviation(randomSource(), randomChild()); + } + + @Override + protected MvMedianAbsoluteDeviation mutateInstance(MvMedianAbsoluteDeviation instance) throws IOException { + return new MvMedianAbsoluteDeviation( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) + ); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java new file mode 100644 index 0000000000000..b041faf6510a1 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvMedianAbsoluteDeviationTests.java @@ -0,0 +1,263 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; +import org.elasticsearch.xpack.esql.core.util.NumericUtils; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MvMedianAbsoluteDeviationTests extends AbstractMultivalueFunctionTestCase { + public MvMedianAbsoluteDeviationTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List cases = new ArrayList<>(); + doubles(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, valuesStream) -> { + var values = valuesStream.sorted().toArray(); + int middle = size / 2; + if (size % 2 == 1) { + double median = values[middle]; + return equalTo(Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle).findFirst().orElseThrow()); + } else { + double median = (values[middle - 1] + values[middle]) / 2; + return equalTo( + Arrays.stream(values).map(d -> Math.abs(d - median)).sorted().skip(middle - 1).limit(2).average().orElseThrow() + ); + } + }); + ints(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = calculateMedianAbsoluteDeviation(size, values.mapToObj(BigInteger::valueOf)); + return equalTo(mad.intValue()); + }); + longs(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = calculateMedianAbsoluteDeviation(size, values.mapToObj(BigInteger::valueOf)); + return equalTo(mad.longValue()); + }); + unsignedLongs(cases, "mv_median_absolute_deviation", "MvMedianAbsoluteDeviation", (size, values) -> { + var mad = calculateMedianAbsoluteDeviation(size, values); + return equalTo(mad); + }); + + // Simple cases + cases.addAll(makeCases(List.of(1, 2, 5), 1, true)); + cases.addAll(makeCases(List.of(1, 2), 0, false)); + cases.addAll(makeCases(List.of(-1, -2), 0, false)); + cases.addAll(makeCases(List.of(0, 2, 5, 6), 2, false)); + + // Overflow cases + cases.addAll( + overflowCasesFor( + DataType.INTEGER, + Integer.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Integer.MAX_VALUE / 2, + Integer.MAX_VALUE / 2 + 1 + ) + ); + cases.addAll( + overflowCasesFor(DataType.LONG, Long.MAX_VALUE, Long.MIN_VALUE, Long.MAX_VALUE, Long.MAX_VALUE / 2L, Long.MAX_VALUE / 2L + 1) + ); + cases.addAll( + overflowCasesFor( + DataType.DOUBLE, + Double.MAX_VALUE, + -Double.MAX_VALUE, + Double.MAX_VALUE, + Double.MAX_VALUE / 2., + Double.MAX_VALUE / 2. + ) + ); + cases.addAll( + overflowCasesFor( + DataType.UNSIGNED_LONG, + NumericUtils.asLongUnsigned(NumericUtils.UNSIGNED_LONG_MAX), + NumericUtils.ZERO_AS_UNSIGNED_LONG, + NumericUtils.UNSIGNED_LONG_MAX.divide(BigInteger.valueOf(2)), + NumericUtils.UNSIGNED_LONG_MAX.divide(BigInteger.valueOf(2)), + BigInteger.ZERO + ) + ); + + // Custom double overflow. Can't be checked in the generic overflow cases, as "MAX_DOUBLE + 1000 == MAX_DOUBLE" + cases.add( + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData( + List.of(-Double.MAX_VALUE, Double.MAX_VALUE / 4, Double.MAX_VALUE / 4), + DataType.DOUBLE, + "field" + ) + ), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(0.) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(false, cases, (v, p) -> "numeric"); + } + + /** + * Makes cases for the given data, for each type + */ + private static List makeCases(List data, Number expectedResult, boolean withDoubles) { + var types = new ArrayList<>(List.of(DataType.INTEGER, DataType.LONG)); + if (withDoubles) { + types.add(DataType.DOUBLE); + } + if (data.stream().noneMatch(d -> d.doubleValue() < 0)) { + types.add(DataType.UNSIGNED_LONG); + } + return types.stream().map(type -> { + var convertedData = data.stream().map(d -> { + var convertedValue = DataTypeConverter.convert(d, type); + if (convertedValue instanceof BigInteger bi) { + return NumericUtils.asLongUnsigned(bi); + } + return convertedValue; + }).toList(); + + return new TestCaseSupplier( + "<" + convertedData + "> (" + type + ")", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(convertedData, type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(DataTypeConverter.convert(expectedResult, type)) + ) + ); + }).toList(); + } + + private static List overflowCasesFor( + DataType type, + Number max, + Number min, + Number maxMinMad, + Number maxZeroMad, + Number minZeroMad + ) { + var zeroExpected = DataTypeConverter.convert(0, type); + var zeroValue = type == DataType.UNSIGNED_LONG ? NumericUtils.ZERO_AS_UNSIGNED_LONG : zeroExpected; + var oneThousandValue = type == DataType.UNSIGNED_LONG + ? NumericUtils.asLongUnsigned(BigInteger.valueOf(1000)) + : DataTypeConverter.convert(1000, type); + + var typeName = type.name().toLowerCase(Locale.ROOT); + + return List.of( + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(max, min), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(maxMinMad) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(max, zeroValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(maxZeroMad) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(min, zeroValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(minZeroMad) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(max, max), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(zeroExpected) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(min, min), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(zeroExpected) + ) + ), + new TestCaseSupplier( + "mv_median_absolute_deviation()", + List.of(type), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(List.of(min, oneThousandValue, oneThousandValue), type, "field")), + "MvMedianAbsoluteDeviation[field=Attribute[channel=0]]", + type, + equalTo(zeroExpected) + ) + ) + ); + } + + private static BigInteger calculateMedianAbsoluteDeviation(int size, Stream valuesStream) { + var values = valuesStream.sorted().toArray(BigInteger[]::new); + int middle = size / 2; + if (size % 2 == 1) { + var median = values[middle]; + return Arrays.stream(values).map(bi -> bi.subtract(median).abs()).sorted().skip(middle).findFirst().orElseThrow(); + } else { + var median = values[middle - 1].add(values[middle]).divide(BigInteger.valueOf(2)); + return Arrays.stream(values) + .map(bi -> bi.subtract(median).abs()) + .sorted() + .skip(middle - 1) + .limit(2) + .reduce(BigInteger.ZERO, BigInteger::add) + .divide(BigInteger.valueOf(2)); + } + } + + @Override + protected Expression build(Source source, Expression field) { + return new MvMedianAbsoluteDeviation(source, field); + } +}