Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,15 @@ PROMQL index=k8s step=1h result=(clamp_max(vector(5.0), 10));
result:double | step:datetime
5.0 | 2024-05-10T00:00:00.000Z
;

quantile
required_capability: promql_pre_tech_preview_v14
required_capability: promql_quantile
PROMQL index=k8s step=1h quantile=(round(quantile by (cluster) (0.5, quantile_over_time(0.5, network.bytes_in[1h])), 0.001))
| SORT cluster;

quantile:double | step:datetime | cluster:keyword
0.395 | 2024-05-10T00:00:00.000Z | prod
1.289 | 2024-05-10T00:00:00.000Z | qa
1.248 | 2024-05-10T00:00:00.000Z | staging
;
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,8 @@ public enum Cap {
*/
PROMQL_CLAMP(Build.current().isSnapshot()),

PROMQL_QUANTILE(PROMQL_PRE_TECH_PREVIEW_V14.isEnabled()),

/**
* KNN function adds support for k and visit_percentage options
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.logical.promql.PromqlDataType;

import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -77,13 +78,17 @@
public class PromqlFunctionRegistry {

// Common parameter definitions
private static final ParamInfo RANGE_VECTOR = ParamInfo.of("v", "range_vector", "Range vector input.");
private static final ParamInfo INSTANT_VECTOR = ParamInfo.of("v", "instant_vector", "Instant vector input.");
private static final ParamInfo SCALAR = ParamInfo.of("s", "scalar", "Scalar value.");
private static final ParamInfo QUANTILE = ParamInfo.of("φ", "scalar", "Quantile value (0 ≤ φ ≤ 1).");
private static final ParamInfo TO_NEAREST = ParamInfo.optional("to_nearest", "scalar", "Round to nearest multiple of this value.");
private static final ParamInfo MIN_SCALAR = ParamInfo.of("min", "scalar", "Minimum value.");
private static final ParamInfo MAX_SCALAR = ParamInfo.of("max", "scalar", "Maximum value.");
private static final ParamInfo RANGE_VECTOR = ParamInfo.child("v", PromqlDataType.RANGE_VECTOR, "Range vector input.");
private static final ParamInfo INSTANT_VECTOR = ParamInfo.child("v", PromqlDataType.INSTANT_VECTOR, "Instant vector input.");
private static final ParamInfo SCALAR = ParamInfo.child("s", PromqlDataType.SCALAR, "Scalar value.");
private static final ParamInfo QUANTILE = ParamInfo.of("φ", PromqlDataType.SCALAR, "Quantile value (0 ≤ φ ≤ 1).");
private static final ParamInfo TO_NEAREST = ParamInfo.optional(
"to_nearest",
PromqlDataType.SCALAR,
"Round to nearest multiple of this value."
);
private static final ParamInfo MIN_SCALAR = ParamInfo.of("min", PromqlDataType.SCALAR, "Minimum value.");
private static final ParamInfo MAX_SCALAR = ParamInfo.of("max", PromqlDataType.SCALAR, "Maximum value.");

private static final FunctionDefinition[] FUNCTION_DEFINITIONS = new FunctionDefinition[] {
//
Expand Down Expand Up @@ -390,13 +395,17 @@ public boolean validate(int paramCount) {
}
}

public record ParamInfo(String name, String type, String description, boolean optional) {
public static ParamInfo of(String name, String type, String description) {
return new ParamInfo(name, type, description, false);
public record ParamInfo(String name, PromqlDataType type, String description, boolean optional, boolean child) {
public static ParamInfo child(String name, PromqlDataType type, String description) {
return new ParamInfo(name, type, description, false, true);
}

public static ParamInfo of(String name, PromqlDataType type, String description) {
return new ParamInfo(name, type, description, false, false);
}

public static ParamInfo optional(String name, String type, String description) {
return new ParamInfo(name, type, description, true);
public static ParamInfo optional(String name, PromqlDataType type, String description) {
return new ParamInfo(name, type, description, true, false);
}
}

Expand Down Expand Up @@ -425,6 +434,20 @@ public record FunctionDefinition(
Objects.requireNonNull(description, "description cannot be null");
Objects.requireNonNull(params, "params cannot be null");
Objects.requireNonNull(examples, "examples cannot be null");
if (arity.max() != params.size()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Arity max %d does not match number of parameters %d for function %s",
arity.max(),
params.size(),
name
)
);
}
if (params.isEmpty() == false && params.stream().filter(ParamInfo::child).count() != 1) {
throw new IllegalArgumentException("If a function takes parameters, there must be exactly one child parameter");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,22 @@ private static Expression mapFunction(
}

List<Expression> extraParams = functionCall.parameters();
return PromqlFunctionRegistry.INSTANCE.buildEsqlFunction(
Expression function = PromqlFunctionRegistry.INSTANCE.buildEsqlFunction(
functionCall.functionName(),
functionCall.source(),
target,
promqlCommand.timestamp(),
window,
extraParams
);
// This can happen when trying to provide a counter to a function that doesn't support it e.g. avg_over_time on a counter
// This is essentially a bug since this limitation doesn't exist in PromQL itself.
// Throwing an error here to avoid generating invalid plans with obscure errors downstream.
Expression.TypeResolution typeResolution = function.typeResolved();
if (typeResolution.unresolved()) {
throw new QlIllegalArgumentException("Could not resolve type for function [{}]: {}", function, typeResolution.message());
}
return function;
}

private static Expression mapScalarFunction(ScalarFunction function) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,26 +433,33 @@ public LogicalPlan visitFunction(PromqlBaseParser.FunctionContext ctx) {
if (paramCount > metadata.arity().max()) {
throw new ParsingException(source, message, name, metadata.arity().max(), paramCount);
}

// child plan is always the first parameter
// TODO this is not the case for the quantile function as the first parameter is the quantile value
LogicalPlan child = params.stream().findFirst().map(param -> switch (param) {
case LogicalPlan plan -> plan;
case Literal literal -> new LiteralSelector(source, literal);
case Node n -> throw new IllegalStateException("Unexpected value: " + n);
}).orElse(null);

// PromQL expects early validation of the tree so let's do it here
PromqlDataType expectedInputType = metadata.functionType().inputType();
PromqlDataType actualInputType = PromqlPlan.getReturnType(child);
if (actualInputType != expectedInputType) {
throw new ParsingException(
source,
"expected type {} in call to function [{}], got {}",
expectedInputType,
name,
actualInputType
);
LogicalPlan child = null;
List<Expression> extraParams = new ArrayList<>(Math.max(0, params.size() - 1));
List<PromqlFunctionRegistry.ParamInfo> functionParams = metadata.params();
for (int i = 0; i < functionParams.size() && params.size() > i; i++) {
PromqlFunctionRegistry.ParamInfo expectedParam = functionParams.get(i);
LogicalPlan providedParam = switch (params.get(i)) {
case LogicalPlan plan -> plan;
case Literal literal -> new LiteralSelector(source, literal);
case Node n -> throw new IllegalStateException("Unexpected value: " + n);
};
PromqlDataType actualType = PromqlPlan.getType(providedParam);
PromqlDataType expectedType = expectedParam.type();
if (actualType != expectedType) {
throw new ParsingException(source, "expected type {} in call to function [{}], got {}", expectedType, name, actualType);
}
if (expectedParam.child()) {
child = providedParam;
} else if (providedParam instanceof LiteralSelector literalSelector) {
extraParams.add(literalSelector.literal());
} else {
throw new ParsingException(
source,
"expected literal parameter in call to function [{}], got {}",
name,
providedParam.nodeName()
);
}
}

PromqlBaseParser.GroupingContext groupingContext = ctx.grouping();
Expand All @@ -477,12 +484,8 @@ public LogicalPlan visitFunction(PromqlBaseParser.FunctionContext ctx) {
for (int i = 0; i < groupingKeys.size(); i++) {
groupings.add(new UnresolvedAttribute(source(labelListCtx.labelName(i)), groupingKeys.get(i)));
}
plan = new AcrossSeriesAggregate(source, child, name, List.of(), grouping, groupings);
plan = new AcrossSeriesAggregate(source, child, name, extraParams, grouping, groupings);
} else {
List<Expression> extraParams = params.stream()
.skip(1) // skip the first param (child)
.map(Expression.class::cast)
.toList();
plan = switch (metadata.functionType()) {
case ACROSS_SERIES_AGGREGATION -> new AcrossSeriesAggregate(
source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public interface PromqlPlan {
* @throws IllegalArgumentException if the plan is not a PromqlPlan
*/
static boolean returnsRangeVector(LogicalPlan plan) {
return getReturnType(plan) == PromqlDataType.RANGE_VECTOR;
return getType(plan) == PromqlDataType.RANGE_VECTOR;
}

/**
Expand All @@ -39,7 +39,7 @@ static boolean returnsRangeVector(LogicalPlan plan) {
* @throws IllegalArgumentException if the plan is not a PromqlPlan
*/
static boolean returnsInstantVector(LogicalPlan plan) {
return getReturnType(plan) == PromqlDataType.INSTANT_VECTOR;
return getType(plan) == PromqlDataType.INSTANT_VECTOR;
}

/**
Expand All @@ -50,10 +50,10 @@ static boolean returnsInstantVector(LogicalPlan plan) {
* @throws IllegalArgumentException if the plan is not a PromqlPlan
*/
static boolean returnsScalar(LogicalPlan plan) {
return getReturnType(plan) == PromqlDataType.SCALAR;
return getType(plan) == PromqlDataType.SCALAR;
}

static PromqlDataType getReturnType(@Nullable LogicalPlan plan) {
static PromqlDataType getType(@Nullable LogicalPlan plan) {
return switch (plan) {
case PromqlPlan promqlPlan -> promqlPlan.returnType();
case null -> null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Set;

/**
Expand Down Expand Up @@ -55,7 +56,7 @@ private void renderKibanaFunctionDefinition() throws Exception {
for (PromqlFunctionRegistry.ParamInfo param : promqlDef.params()) {
builder.startObject();
builder.field("name", param.name());
builder.field("type", param.type());
builder.field("type", param.type().name().toLowerCase(Locale.ROOT));
builder.field("optional", param.optional());
builder.field("description", param.description());
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ public void testConstantResults() {
assertConstantResult("ceil(vector(3.14159))", equalTo(4.0));
assertConstantResult("pi()", equalTo(Math.PI));
assertConstantResult("abs(vector(-1))", equalTo(1.0));
assertConstantResult("quantile(0.5, vector(1))", equalTo(1.0));
}

public void testRound() {
Expand All @@ -728,30 +729,23 @@ public void testRound() {
}

public void testClamp() {
assertScalarFunctionResult("clamp(vector(5), 0, 10)", 5.0);
assertScalarFunctionResult("clamp(vector(-5), 0, 10)", 0.0);
assertScalarFunctionResult("clamp(vector(15), 0, 10)", 10.0);
assertScalarFunctionResult("clamp(vector(0), 0, 10)", 0.0);
assertScalarFunctionResult("clamp(vector(10), 0, 10)", 10.0);
assertConstantResult("clamp(vector(5), 0, 10)", equalTo(5.0));
assertConstantResult("clamp(vector(-5), 0, 10)", equalTo(0.0));
assertConstantResult("clamp(vector(15), 0, 10)", equalTo(10.0));
assertConstantResult("clamp(vector(0), 0, 10)", equalTo(0.0));
assertConstantResult("clamp(vector(10), 0, 10)", equalTo(10.0));
}

public void testClampMin() {
assertScalarFunctionResult("clamp_min(vector(5), 0)", 5.0);
assertScalarFunctionResult("clamp_min(vector(-5), 0)", 0.0);
assertScalarFunctionResult("clamp_min(vector(0), 0)", 0.0);
assertConstantResult("clamp_min(vector(5), 0)", equalTo(5.0));
assertConstantResult("clamp_min(vector(-5), 0)", equalTo(0.0));
assertConstantResult("clamp_min(vector(0), 0)", equalTo(0.0));
}

public void testClampMax() {
assertScalarFunctionResult("clamp_max(vector(5), 10)", 5.0);
assertScalarFunctionResult("clamp_max(vector(15), 10)", 10.0);
assertScalarFunctionResult("clamp_max(vector(10), 10)", 10.0);
}

private void assertScalarFunctionResult(String promqlExpr, double expectedValue) {
var plan = planPromqlExpectNoReferences("PROMQL index=k8s step=1m result=(" + promqlExpr + ")");
Eval eval = plan.collect(Eval.class).getFirst();
Literal literal = as(eval.fields().getFirst().child(), Literal.class);
assertThat(literal.value(), equalTo(expectedValue));
assertConstantResult("clamp_max(vector(5), 10)", equalTo(5.0));
assertConstantResult("clamp_max(vector(15), 10)", equalTo(10.0));
assertConstantResult("clamp_max(vector(10), 10)", equalTo(10.0));
}

private void assertConstantResult(String query, Matcher<Double> matcher) {
Expand Down