Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ protected static List<TestCaseSupplier> anyNullIsNull(
ExpectedType expectedType,
ExpectedEvaluatorToString evaluatorToString
) {
typesRequired(testCaseSuppliers);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's now impossible to build test cases without types so this test doesn't do anything.

List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand Down Expand Up @@ -274,7 +273,7 @@ protected static List<TestCaseSupplier> anyNullIsNull(
}

@FunctionalInterface
protected interface PositionalErrorMessageSupplier {
public interface PositionalErrorMessageSupplier {
/**
* This interface defines functions to supply error messages for incorrect types in specific positions. Functions which have
* the same type requirements for all positions can simplify this with a lambda returning a string constant.
Expand All @@ -291,7 +290,9 @@ protected interface PositionalErrorMessageSupplier {
/**
* Adds test cases containing unsupported parameter types that assert
* that they throw type errors.
* @deprecated make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase} instead
*/
@Deprecated
protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
List<TestCaseSupplier> testCaseSuppliers,
PositionalErrorMessageSupplier positionalErrorMessageSupplier
Expand Down Expand Up @@ -331,11 +332,14 @@ protected interface TypeErrorMessageSupplier {
String apply(boolean includeOrdinal, List<Set<DataType>> validPerPosition, List<DataType> types);
}

/**
* @deprecated make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase} instead
*/
@Deprecated
protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
List<TestCaseSupplier> testCaseSuppliers,
TypeErrorMessageSupplier typeErrorMessageSupplier
) {
typesRequired(testCaseSuppliers);
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand All @@ -346,7 +350,7 @@ protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
.map(s -> s.types().size())
.collect(Collectors.toSet())
.stream()
.flatMap(count -> allPermutations(count))
.flatMap(AbstractFunctionTestCase::allPermutations)
.filter(types -> valid.contains(types) == false)
/*
* Skip any cases with more than one null. Our tests don't generate
Expand All @@ -366,10 +370,6 @@ private static List<DataType> append(List<DataType> orig, DataType extra) {
return longer;
}

protected static Stream<DataType> representable() {
return DataType.types().stream().filter(DataType::isRepresentable);
}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused.


protected static TestCaseSupplier typeErrorSupplier(
boolean includeOrdinal,
List<Set<DataType>> validPerPosition,
Expand Down Expand Up @@ -398,7 +398,7 @@ protected static TestCaseSupplier typeErrorSupplier(
);
}

private static List<Set<DataType>> validPerPosition(Set<List<DataType>> valid) {
static List<Set<DataType>> validPerPosition(Set<List<DataType>> valid) {
int max = valid.stream().mapToInt(List::size).max().getAsInt();
List<Set<DataType>> result = new ArrayList<>(max);
for (int i = 0; i < max; i++) {
Expand Down Expand Up @@ -1327,17 +1327,6 @@ public void allMemoryReleased() {
}
}

/**
* Validate that we know the types for all the test cases already created
* @param suppliers - list of suppliers before adding in the illegal type combinations
*/
protected static void typesRequired(List<TestCaseSupplier> suppliers) {
String bad = suppliers.stream().filter(s -> s.types() == null).map(s -> s.name()).collect(Collectors.joining("\n"));
if (bad.equals("") == false) {
throw new IllegalArgumentException("types required but not found for these tests:\n" + bad);
}
}

/**
* Returns true if the current test case is for an aggregation function.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
* </p>
*
* @param entirelyNullPreservesType See {@link #anyNullIsNull(boolean, List)}
* @deprecated use {@link #parameterSuppliersFromTypedDataWithDefaultChecksNoErrors}
* and make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase}.
* It's a <strong>long</strong> faster.
*/
@Deprecated
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecks(
boolean entirelyNullPreservesType,
List<TestCaseSupplier> suppliers,
Expand All @@ -72,6 +76,23 @@ protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultCh
);
}

/**
* Converts a list of test cases into a list of parameter suppliers.
* Also, adds a default set of extra test cases.
* <p>
* Use if possible, as this method may get updated with new checks in the future.
* </p>
*
* @param entirelyNullPreservesType See {@link #anyNullIsNull(boolean, List)}
*/
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(
// TODO remove after removing parameterSuppliersFromTypedDataWithDefaultChecks rename this to that.
boolean entirelyNullPreservesType,
List<TestCaseSupplier> suppliers
) {
return parameterSuppliersFromTypedData(anyNullIsNull(entirelyNullPreservesType, randomizeBytesRefsOffset(suppliers)));
}

/**
* Converts a list of test cases into a list of parameter suppliers.
* Also, adds a default set of extra test cases.
Expand Down Expand Up @@ -364,43 +385,10 @@ public void testFold() {
}
}

public static String errorMessageStringForBinaryOperators(
boolean includeOrdinal,
List<Set<DataType>> validPerPosition,
List<DataType> types,
PositionalErrorMessageSupplier positionalErrorMessageSupplier
) {
try {
return typeErrorMessage(includeOrdinal, validPerPosition, types, positionalErrorMessageSupplier);
} catch (IllegalStateException e) {
// This means all the positional args were okay, so the expected error is from the combination
if (types.get(0).equals(DataType.UNSIGNED_LONG)) {
return "first argument of [] is [unsigned_long] and second is ["
+ types.get(1).typeName()
+ "]. [unsigned_long] can only be operated on together with another [unsigned_long]";

}
if (types.get(1).equals(DataType.UNSIGNED_LONG)) {
return "first argument of [] is ["
+ types.get(0).typeName()
+ "] and second is [unsigned_long]. [unsigned_long] can only be operated on together with another [unsigned_long]";
}
return "first argument of [] is ["
+ (types.get(0).isNumeric() ? "numeric" : types.get(0).typeName())
+ "] so second argument must also be ["
+ (types.get(0).isNumeric() ? "numeric" : types.get(0).typeName())
+ "] but was ["
+ types.get(1).typeName()
+ "]";

}
}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved.


/**
* Adds test cases containing unsupported parameter types that immediately fail.
*/
protected static List<TestCaseSupplier> failureForCasesWithoutExamples(List<TestCaseSupplier> testCaseSuppliers) {
typesRequired(testCaseSuppliers);
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.hamcrest.Matcher;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
import static org.hamcrest.Matchers.greaterThan;

public abstract class ErrorsForCasesWithoutExamplesTestCase extends ESTestCase {
protected abstract List<TestCaseSupplier> cases();

/**
* Build the expression being tested, for the given source and list of arguments. Test classes need to implement this
* to have something to test.
*
* @param source the source
* @param args arg list from the test case, should match the length expected
* @return an expression for evaluating the function being tested on the given arguments
*/
protected abstract Expression build(Source source, List<Expression> args);

protected abstract Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature);

protected final List<TestCaseSupplier> paramsToSuppliers(Iterable<Object[]> cases) {
List<TestCaseSupplier> result = new ArrayList<>();
for (Object[] c : cases) {
if (c.length != 1) {
throw new IllegalArgumentException("weird layout for test cases");
}
TestCaseSupplier supplier = (TestCaseSupplier) c[0];
result.add(supplier);
}
return result;
}

public final void test() {
int checked = 0;
List<TestCaseSupplier> cases = cases();
Set<List<DataType>> valid = cases.stream().map(TestCaseSupplier::types).collect(Collectors.toSet());
List<Set<DataType>> validPerPosition = AbstractFunctionTestCase.validPerPosition(valid);
Iterable<List<DataType>> missingSignatures = missingSignatures(cases, valid)::iterator;
for (List<DataType> signature : missingSignatures) {
logger.debug("checking {}", signature);
List<Expression> args = new ArrayList<>(signature.size());
for (DataType type : signature) {
args.add(randomLiteral(type));
}
Expression expression = build(Source.synthetic(sourceForSignature(signature)), args);
assertTrue("expected unresolved " + expression, expression.typeResolved().unresolved());
assertThat(expression.typeResolved().message(), expectedTypeErrorMatcher(validPerPosition, signature));
checked++;
}
logger.info("checked {} signatures", checked);
assertThat("didn't check any signatures", checked, greaterThan(0));
}

private Stream<List<DataType>> missingSignatures(List<TestCaseSupplier> cases, Set<List<DataType>> valid) {
return cases.stream()
.map(s -> s.types().size())
.collect(Collectors.toSet())
.stream()
.flatMap(AbstractFunctionTestCase::allPermutations)
.filter(types -> valid.contains(types) == false)
/*
* Skip any cases with more than one null. Our tests don't generate
* the full combinatorial explosions of all nulls - just a single null.
* Hopefully <null>, <null> cases will function the same as <null>, <valid>
* cases.
*/
.filter(types -> types.stream().filter(t -> t == DataType.NULL).count() <= 1);
}

protected static String sourceForSignature(List<DataType> signature) {
StringBuilder source = new StringBuilder();
for (DataType type : signature) {
if (false == source.isEmpty()) {
source.append(", ");
}
source.append(type.typeName());
}
return source.toString();
}

/**
* Build the expected error message for an invalid type signature.
*/
protected static String typeErrorMessage(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For those following along, this is copied out of AbstactFunctionTestCase but modified. I copied it because I think that once we apply this pattern globally we'll delete the function I copied it from.

boolean includeOrdinal,
List<Set<DataType>> validPerPosition,
List<DataType> signature,
AbstractFunctionTestCase.PositionalErrorMessageSupplier expectedTypeSupplier
) {
int badArgPosition = -1;
for (int i = 0; i < signature.size(); i++) {
if (validPerPosition.get(i).contains(signature.get(i)) == false) {
badArgPosition = i;
break;
}
}
if (badArgPosition == -1) {
throw new IllegalStateException(
"Can't generate error message for these types, you probably need a custom error message function"
);
}
String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " " : "";
String source = sourceForSignature(signature);
String expectedTypeString = expectedTypeSupplier.apply(validPerPosition.get(badArgPosition), badArgPosition);
String name = signature.get(badArgPosition).typeName();
return ordinal + "argument of [" + source + "] must be [" + expectedTypeString + "], found value [] type [" + name + "]";
}

protected static String errorMessageStringForBinaryOperators(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved from AbstractFunctionTestCase and updated for the slightly different error messages this functions builds. the + source + bit.

List<Set<DataType>> validPerPosition,
List<DataType> signature,
AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
) {
try {
return typeErrorMessage(true, validPerPosition, signature, positionalErrorMessageSupplier);
} catch (IllegalStateException e) {
String source = sourceForSignature(signature);
// This means all the positional args were okay, so the expected error is from the combination
if (signature.get(0).equals(DataType.UNSIGNED_LONG)) {
return "first argument of ["
+ source
+ "] is [unsigned_long] and second is ["
+ signature.get(1).typeName()
+ "]. [unsigned_long] can only be operated on together with another [unsigned_long]";

}
if (signature.get(1).equals(DataType.UNSIGNED_LONG)) {
return "first argument of ["
+ source
+ "] is ["
+ signature.get(0).typeName()
+ "] and second is [unsigned_long]. [unsigned_long] can only be operated on together with another [unsigned_long]";
}
return "first argument of ["
+ source
+ "] is ["
+ (signature.get(0).isNumeric() ? "numeric" : signature.get(0).typeName())
+ "] so second argument must also be ["
+ (signature.get(0).isNumeric() ? "numeric" : signature.get(0).typeName())
+ "] but was ["
+ signature.get(1).typeName()
+ "]";

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.grouping;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.hamcrest.Matcher;

import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

public class CategorizeErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
@Override
protected List<TestCaseSupplier> cases() {
return paramsToSuppliers(CategorizeTests.parameters());
}

@Override
protected Expression build(Source source, List<Expression> args) {
return new Categorize(source, args.get(0));
}

@Override
protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
return equalTo(typeErrorMessage(false, validPerPosition, signature, (v, p) -> "string"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static Iterable<Object[]> parameters() {
)
);
}
return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "string");
return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}

@Override
Expand Down
Loading