diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java index c51f595ab95ce..34d05e66ad0b4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.scalar.score.Decay.DecayFunction; import org.hamcrest.Matcher; import java.math.BigDecimal; @@ -638,9 +639,10 @@ private static List intRandomTestCases() { return List.of(new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER, DataType.INTEGER, DataType.SOURCE), () -> { int randomValue = randomInt(); int randomOrigin = randomInt(); - int randomScale = randomInt(); + // scale must be > 0 + int randomScale = randomIntBetween(1, Integer.MAX_VALUE); int randomOffset = randomInt(); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = getRandomType(); double scoreScriptNumericResult = intDecayWithScoreScript( @@ -712,9 +714,9 @@ private static List longRandomTestCases() { return List.of(new TestCaseSupplier(List.of(DataType.LONG, DataType.LONG, DataType.LONG, DataType.SOURCE), () -> { long randomValue = randomLong(); long randomOrigin = randomLong(); - long randomScale = randomLong(); + long randomScale = randomLongBetween(1L, Long.MAX_VALUE); long randomOffset = randomLong(); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = randomFrom("linear", "gauss", "exp"); double scoreScriptNumericResult = longDecayWithScoreScript( @@ -782,9 +784,9 @@ private static List doubleRandomTestCases() { return List.of(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE, DataType.DOUBLE, DataType.SOURCE), () -> { double randomValue = randomLong(); double randomOrigin = randomLong(); - double randomScale = randomLong(); + double randomScale = randomLongBetween(1L, Long.MAX_VALUE); double randomOffset = randomLong(); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = randomFrom("linear", "gauss", "exp"); double scoreScriptNumericResult = doubleDecayWithScoreScript( @@ -882,7 +884,7 @@ private static List geoPointRandomTestCases() { GeoPoint randomOrigin = randomGeoPoint(); String randomScale = randomDistance(); String randomOffset = randomDistance(); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = randomDecayType(); double scoreScriptNumericResult = geoPointDecayWithScoreScript( @@ -1061,7 +1063,7 @@ private static List datetimeRandomTestCases() { long randomOffsetMillis = randomNonNegativeLong() % (30L * 24 * 60 * 60 * 1000); Duration randomScale = Duration.ofMillis(randomScaleMillis); Duration randomOffset = Duration.ofMillis(randomOffsetMillis); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = randomFrom("linear", "gauss", "exp"); double scoreScriptNumericResult = datetimeDecayWithScoreScript( @@ -1148,14 +1150,14 @@ private static List dateNanosRandomTestCases() { Duration randomScale = Duration.ofMillis(randomScaleMillis); Duration randomOffset = Duration.ofMillis(randomOffsetMillis); - double randomDecay = randomDouble(); + double randomDecay = randomDecayOpenUnitInterval(); String randomType = randomFrom("linear", "gauss", "exp"); - double scoreScriptNumericResult = dateNanosDecayWithScoreScript( + double scoreScriptNumericResult = expectedDateNanosTemporalDecay( randomValue, randomOrigin, - randomScale.toMillis(), - randomOffset.toMillis(), + randomScale.toNanos(), + randomOffset.toNanos(), randomDecay, randomType ); @@ -1176,24 +1178,30 @@ private static List dateNanosRandomTestCases() { ); } - private static double dateNanosDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) { - long valueMillis = value / 1_000_000L; - long originMillis = origin / 1_000_000L; - - String originStr = String.valueOf(originMillis); - String scaleStr = scale + "ms"; - String offsetStr = offset + "ms"; - - ZonedDateTime valueDateTime = Instant.ofEpochMilli(valueMillis).atZone(ZoneId.of("UTC")); + private static double expectedDateNanosTemporalDecay( + long valueNanos, + long originNanos, + long scaleNanos, + long offsetNanos, + double decay, + String type + ) { + return decayFunctionForType(type).temporalDecay(valueNanos, originNanos, scaleNanos, offsetNanos, decay); + } + private static DecayFunction decayFunctionForType(String type) { return switch (type) { - case "linear" -> new ScoreScriptUtils.DecayDateLinear(originStr, scaleStr, offsetStr, decay).decayDateLinear(valueDateTime); - case "gauss" -> new ScoreScriptUtils.DecayDateGauss(originStr, scaleStr, offsetStr, decay).decayDateGauss(valueDateTime); - case "exp" -> new ScoreScriptUtils.DecayDateExp(originStr, scaleStr, offsetStr, decay).decayDateExp(valueDateTime); + case "linear" -> DecayFunction.LINEAR; + case "gauss" -> DecayFunction.GAUSSIAN; + case "exp" -> DecayFunction.EXPONENTIAL; default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); }; } + private static double randomDecayOpenUnitInterval() { + return randomDoubleBetween(1e-12, Math.nextDown(1.0), true); + } + private static MapExpression createOptionsMap(Object offset, Double decay, String functionType) { List keyValuePairs = new ArrayList<>();