Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
(arguments.isEmpty() || arguments.size() == 1)
? Collections.emptyList()
: arguments.subList(1, arguments.size());
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args);
return PlanUtils.makeOver(
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public class PPLOperandTypes {
private PPLOperandTypes() {}

/** Operand metadata wrapping a family checker built from an empty family list, i.e. no operands. */
public static final UDFOperandMetadata NONE = UDFOperandMetadata.wrap(OperandTypes.family());
/**
 * Either a single operand of the {@code ANY} family or no operand at all. The no-operand
 * alternative is the same empty-family checker expression that {@link #NONE} wraps.
 */
public static final UDFOperandMetadata OPTIONAL_ANY =
UDFOperandMetadata.wrap(
// NOTE(review): the cast presumably selects the wrap overload for composite checkers — confirm.
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.ANY).or(OperandTypes.family()));
/** Either a single operand accepted by {@code OperandTypes.INTEGER} or no operand at all. */
public static final UDFOperandMetadata OPTIONAL_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
Expand All @@ -43,6 +47,10 @@ private PPLOperandTypes() {}
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)));
/**
 * One operand accepted by {@code OperandTypes.ANY}, optionally followed by a second operand of
 * the {@code TIMESTAMP} family. The two alternatives are combined with {@code or}, so either
 * arity is accepted.
 */
public static final UDFOperandMetadata ANY_OPTIONAL_TIMESTAMP =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP)));
/** Wraps Calcite's {@code OperandTypes.INTEGER_INTEGER} checker. */
public static final UDFOperandMetadata INTEGER_INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
public static final UDFOperandMetadata STRING_STRING =
Expand Down Expand Up @@ -121,6 +129,12 @@ private PPLOperandTypes() {}
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));
/**
 * Three alternatives combined with {@code or}: a single {@code ANY}-family operand; an {@code
 * ANY} operand followed by a {@code DATETIME}-family operand; or an {@code ANY} operand followed
 * by a {@code STRING}-family operand.
 */
public static final UDFOperandMetadata ANY_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.ANY)
.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.DATETIME))
.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.STRING)));

/** Exactly two operands, both of the {@code DATETIME} family. */
public static final UDFOperandMetadata DATETIME_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
Expand Down Expand Up @@ -391,4 +394,19 @@ public Void visitInputRef(RexInputRef inputRef) {
visitor.visitEach(rexNodes);
return selectedColumns;
}

/**
 * Renders argument types as a bracketed, comma-separated list for use in error messages.
 *
 * <p>Each type is converted with {@link OpenSearchTypeFactory#convertRelDataTypeToExprType} and
 * then stringified; the pieces are joined with {@code ","} and wrapped in square brackets.
 *
 * @param argTypes the list of argument types as {@link RelDataType}
 * @return a string in the format {@code [type1,type2,...]}; {@code []} when the list is empty
 */
public static String getActualSignature(List<RelDataType> argTypes) {
  return argTypes.stream()
      .map(type -> Objects.toString(OpenSearchTypeFactory.convertRelDataTypeToExprType(type)))
      .collect(Collectors.joining(",", "[", "]"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ public enum BuiltinFunctionName {
.put("stddev", BuiltinFunctionName.STDDEV_POP)
.put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
.put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
// .put("earliest", BuiltinFunctionName.EARLIEST)
// .put("latest", BuiltinFunctionName.LATEST)
.put("earliest", BuiltinFunctionName.EARLIEST)
.put("latest", BuiltinFunctionName.LATEST)
.put("distinct_count_approx", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
.put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
.put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -261,7 +260,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
Expand Down Expand Up @@ -407,25 +405,40 @@ public void registerExternalAggOperator(
aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler));
}

public void validateAggFunctionSignature(
Copy link
Collaborator

@Swiddis Swiddis Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider using an AggFunctionSignature class instead of shotgun validation

From a safety standpoint, it would be better design to create a record type and construct it with validation. Then any methods that require the function signature can take an AggFunctionSignature, and it's guaranteed that the input is valid since you can't construct the object otherwise.

Otherwise, we rely on the rest of the code to "just know" whether the input is already validated or not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually using the signature that comes from the function registry (this method just delegates the validation). It needs to be done here because we convert some aggregate/window functions into different ones, and add, remove, or change their attributes along the way.

BuiltinFunctionName functionName, RexNode field, List<RexNode> argList) {
var implementation = getImplementation(functionName);
validateFunctionArgs(implementation, functionName, field, argList);
}

/**
 * Resolves an aggregate function into a {@link RelBuilder.AggCall}.
 *
 * <p>Looks up the registered implementation for {@code functionName}, checks the call's arguments
 * with {@code validateFunctionArgs}, and then delegates to the implementation's handler.
 *
 * @param functionName the aggregate function to resolve
 * @param distinct passed through unchanged to the handler
 * @param field the field argument; passed to both validation and the handler
 * @param argList additional arguments; passed to both validation and the handler
 * @param context the plan context handed to the handler
 * @return the aggregate call produced by the handler
 * @throws IllegalStateException if no implementation is registered for {@code functionName}
 */
public RelBuilder.AggCall resolveAgg(
BuiltinFunctionName functionName,
boolean distinct,
RexNode field,
List<RexNode> argList,
CalcitePlanContext context) {
// NOTE(review): `implementation` is declared both by the inline registry lookup here and by the
// getImplementation(...) call below; this view appears to interleave the pre- and post-change
// lines of a diff — confirm which form is current.
var implementation = aggExternalFunctionRegistry.get(functionName);
if (implementation == null) {
implementation = aggFunctionRegistry.get(functionName);
}
if (implementation == null) {
throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
}
var implementation = getImplementation(functionName);

// Validation is done based on original argument types to generate error from user perspective.
validateFunctionArgs(implementation, functionName, field, argList);

var handler = implementation.getValue();
return handler.apply(distinct, field, argList, context);
}

static void validateFunctionArgs(
Pair<CalciteFuncSignature, AggHandler> implementation,
BuiltinFunctionName functionName,
RexNode field,
List<RexNode> argList) {
CalciteFuncSignature signature = implementation.getKey();

List<RelDataType> argTypes = new ArrayList<>();
if (field != null) {
argTypes.add(field.getType());
}
// Currently only PERCENTILE_APPROX and TAKE have additional arguments.

// Currently only PERCENTILE_APPROX, TAKE, EARLIEST, and LATEST have additional arguments.
// Their additional arguments will always come as a map of <argName, value>
List<RelDataType> additionalArgTypes =
argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).toList();
Expand All @@ -441,10 +454,20 @@ public RelBuilder.AggCall resolveAgg(
errorMessagePattern,
functionName,
signature.typeChecker().getAllowedSignatures(),
getActualSignature(argTypes)));
PlanUtils.getActualSignature(argTypes)));
}
var handler = implementation.getValue();
return handler.apply(distinct, field, argList, context);
}

/**
 * Looks up the registered aggregate implementation for a function.
 *
 * <p>The external registry is consulted first; the built-in registry is queried only when the
 * external one has no entry for the function.
 *
 * @param functionName the aggregate function to look up
 * @return the signature/handler pair registered for {@code functionName}
 * @throws IllegalStateException if neither registry has an entry for {@code functionName}
 */
private Pair<CalciteFuncSignature, AggHandler> getImplementation(
    BuiltinFunctionName functionName) {
  Pair<CalciteFuncSignature, AggHandler> external = aggExternalFunctionRegistry.get(functionName);
  Pair<CalciteFuncSignature, AggHandler> resolved =
      external != null ? external : aggFunctionRegistry.get(functionName);
  if (resolved != null) {
    return resolved;
  }
  throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
}

public RexNode resolve(final RexBuilder builder, final String functionName, RexNode... args) {
Expand Down Expand Up @@ -492,7 +515,7 @@ public RexNode resolve(
throw new ExpressionEvaluationException(
String.format(
"Cannot resolve function: %s, arguments: %s, caused by: %s",
functionName, getActualSignature(argTypes), e.getMessage()),
functionName, PlanUtils.getActualSignature(argTypes), e.getMessage()),
e);
}
StringJoiner allowedSignatures = new StringJoiner(",");
Expand All @@ -505,7 +528,7 @@ functionName, getActualSignature(argTypes), e.getMessage()),
throw new ExpressionEvaluationException(
String.format(
"%s function expects {%s}, but got %s",
functionName, allowedSignatures, getActualSignature(argTypes)));
functionName, allowedSignatures, PlanUtils.getActualSignature(argTypes)));
}

/**
Expand Down Expand Up @@ -1072,21 +1095,6 @@ void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFuncti
register(functionName, handler, typeChecker);
}

/**
 * Picks the time field for an aggregate that orders rows by time.
 *
 * <p>When {@code argList} is non-empty, its first element is unwrapped with {@link
 * PlanUtils#derefMapCall}. Otherwise the {@code @timestamp} field of the row type returned by
 * {@code ctx.relBuilder.peek()} is referenced.
 *
 * @param argList the additional arguments of the aggregate call; may be empty
 * @param ctx the plan context providing the rel and rex builders
 * @return the expression to use as the time field
 * @throws IllegalArgumentException if {@code argList} is empty and no {@code @timestamp} field
 *     exists
 */
private static RexNode resolveTimeField(List<RexNode> argList, CalcitePlanContext ctx) {
  if (!argList.isEmpty()) {
    return PlanUtils.derefMapCall(argList.get(0));
  }
  // No explicit time field was given: fall back to the default @timestamp field.
  var defaultTimeField = ctx.relBuilder.peek().getRowType().getField("@timestamp", false, false);
  if (defaultTimeField == null) {
    throw new IllegalArgumentException(
        "Default @timestamp field not found. Please specify a time field explicitly.");
  }
  return ctx.rexBuilder.makeInputRef(defaultTimeField.getType(), defaultTimeField.getIndex());
}

void populate() {
registerOperator(MAX, SqlStdOperatorTable.MAX);
registerOperator(MIN, SqlStdOperatorTable.MIN);
Expand Down Expand Up @@ -1116,8 +1124,7 @@ void populate() {
return ctx.relBuilder.count(distinct, null, field);
}
},
wrapSqlOperandTypeChecker(
SqlStdOperatorTable.COUNT.getOperandTypeChecker(), COUNT.name(), false));
wrapSqlOperandTypeChecker(PPLOperandTypes.OPTIONAL_ANY, COUNT.name(), false));

register(
PERCENTILE_APPROX,
Expand Down Expand Up @@ -1164,20 +1171,22 @@ void populate() {
register(
EARLIEST,
(distinct, field, argList, ctx) -> {
RexNode timeField = resolveTimeField(argList, ctx);
return ctx.relBuilder.aggregateCall(SqlStdOperatorTable.ARG_MIN, field, timeField);
List<RexNode> args = resolveTimeField(argList, ctx);
return UserDefinedFunctionUtils.makeAggregateCall(
SqlStdOperatorTable.ARG_MIN, List.of(field), args, ctx.relBuilder);
},
wrapSqlOperandTypeChecker(
SqlStdOperatorTable.ARG_MIN.getOperandTypeChecker(), EARLIEST.name(), false));
PPLOperandTypes.ANY_OPTIONAL_TIMESTAMP, EARLIEST.name(), false));

// Register LATEST: built on Calcite's ARG_MAX, with `field` as the value operand and the
// resolved time field(s) as the remaining arguments (explicit argument if given, otherwise the
// default @timestamp field — see resolveTimeField).
register(
    LATEST,
    (distinct, field, argList, ctx) -> {
      List<RexNode> args = resolveTimeField(argList, ctx);
      return UserDefinedFunctionUtils.makeAggregateCall(
          SqlStdOperatorTable.ARG_MAX, List.of(field), args, ctx.relBuilder);
    },
    // The checker must carry LATEST's own name (not EARLIEST's) so that diagnostics and log
    // messages produced for this registration are attributed to the right function.
    wrapSqlOperandTypeChecker(
        PPLOperandTypes.ANY_OPTIONAL_TIMESTAMP, LATEST.name(), false));

// Register FIRST function - uses document order
register(
Expand All @@ -1201,19 +1210,19 @@ void populate() {
}
}

/**
* Get a string representation of the argument types expressed in ExprType for error messages.
*
* @param argTypes the list of argument types as {@link RelDataType}
* @return a string in the format [type1,type2,...] representing the argument types
*/
private static String getActualSignature(List<RelDataType> argTypes) {
return "["
+ argTypes.stream()
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
.map(Objects::toString)
.collect(Collectors.joining(","))
+ "]";
static List<RexNode> resolveTimeField(List<RexNode> argList, CalcitePlanContext ctx) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse this in any function that has an implicit @timestamp field? See #4138.

Copy link
Collaborator Author

@ykmr1224 ykmr1224 Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method relies on the position of the time-field parameter, so it cannot be reused directly, but we should unify the logic for referring to the implicit (default) timestamp field.

Added tracking issue: #4275

if (argList.isEmpty()) {
// Try to find @timestamp field
var timestampField = ctx.relBuilder.peek().getRowType().getField("@timestamp", false, false);
if (timestampField == null) {
throw new IllegalArgumentException(
"Default @timestamp field not found. Please specify a time field explicitly.");
}
return List.of(
ctx.rexBuilder.makeInputRef(timestampField.getType(), timestampField.getIndex()));
} else {
return argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
}
}

/**
Expand Down Expand Up @@ -1257,6 +1266,8 @@ private static PPLTypeChecker wrapSqlOperandTypeChecker(
pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker);
} else if (typeChecker instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes());
} else if (typeChecker != null) {
pplTypeChecker = PPLTypeChecker.wrapDefault(typeChecker);
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking", functionName);
Expand Down
Loading
Loading