diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
index d144d7601349d..d5fe1b4a697e0 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
@@ -191,7 +191,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
);
case TOP_N_LONGS -> List.of(
- new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
+ new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
);
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md
new file mode 100644
index 0000000000000..acd2064002b44
--- /dev/null
+++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md
@@ -0,0 +1,13 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Supported function named parameters**
+
+`output_format`
+: (keyword) The output format of the categories. Defaults to regex.
+
+`similarity_threshold`
+: (integer) The minimum percentage of token weight that must match for text to be added to the category bucket. Must be between 1 and 100. The larger the value the narrower the categories. Larger values will increase memory usage and create narrower categories. Defaults to 70.
+
+`analyzer`
+: (keyword) Analyzer used to convert the field into tokens for text categorization.
+
diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md
index ca23c1e2efc23..2e331187665f4 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md
@@ -19,5 +19,8 @@
:::{include} ../types/categorize.md
:::
+:::{include} ../functionNamedParams/categorize.md
+:::
+
:::{include} ../examples/categorize.md
:::
diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md
index 8733908754570..c013b67375a3d 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md
@@ -5,3 +5,6 @@
`field`
: Expression to categorize
+`options`
+: (Optional) Categorize additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params).
+
diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md
index 6043fbe719ff8..8ebe22b61286c 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md
@@ -2,8 +2,8 @@
**Supported types**
-| field | result |
-| --- | --- |
-| keyword | keyword |
-| text | keyword |
+| field | options | result |
+| --- | --- | --- |
+| keyword | | keyword |
+| text | | keyword |
diff --git a/docs/reference/query-languages/esql/images/functions/categorize.svg b/docs/reference/query-languages/esql/images/functions/categorize.svg
index bbb2bda7c480b..7629b9bb978ba 100644
--- a/docs/reference/query-languages/esql/images/functions/categorize.svg
+++ b/docs/reference/query-languages/esql/images/functions/categorize.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index ae0ccecf15ed7..99c255acf7268 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -340,6 +340,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00);
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
+ public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
index 1cae296f09c02..63f4d9c96bcd0 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
@@ -128,16 +128,26 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
/**
- * @param isCategorize Whether this group is a CATEGORIZE() or not.
- * May be changed in the future when more stateful grouping functions are added.
+ * Configuration for a BlockHash group spec that is doing text categorization.
*/
- public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
+ public record CategorizeDef(String analyzer, OutputFormat outputFormat, int similarityThreshold) {
+ public enum OutputFormat {
+ REGEX,
+ TOKENS
+ }
+ }
+
+ public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) {
public GroupSpec(int channel, ElementType elementType) {
- this(channel, elementType, false, null);
+ this(channel, elementType, null, null);
+ }
+
+ public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) {
+ this(channel, elementType, categorizeDef, null);
}
- public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
- this(channel, elementType, isCategorize, null);
+ public boolean isCategorize() {
+ return categorizeDef != null;
}
}
@@ -207,7 +217,13 @@ public static BlockHash buildCategorizeBlockHash(
int emitBatchSize
) {
if (groups.size() == 1) {
- return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
+ return new CategorizeBlockHash(
+ blockFactory,
+ groups.get(0).channel,
+ aggregatorMode,
+ groups.get(0).categorizeDef,
+ analysisRegistry
+ );
} else {
assert groups.get(0).isCategorize();
assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java
index 5e716d8c9d5ff..fcc1a7f3d271e 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java
@@ -18,7 +18,6 @@
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
-import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
@@ -47,12 +46,13 @@
*/
public class CategorizeBlockHash extends BlockHash {
- private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig
+ private static final CategorizationAnalyzerConfig DEFAULT_ANALYZER_CONFIG = CategorizationAnalyzerConfig
.buildStandardEsqlCategorizationAnalyzer();
private static final int NULL_ORD = 0;
private final int channel;
private final AggregatorMode aggregatorMode;
+ private final CategorizeDef categorizeDef;
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
private final CategorizeEvaluator evaluator;
@@ -64,28 +64,38 @@ public class CategorizeBlockHash extends BlockHash {
*/
private boolean seenNull = false;
- CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
+ CategorizeBlockHash(
+ BlockFactory blockFactory,
+ int channel,
+ AggregatorMode aggregatorMode,
+ CategorizeDef categorizeDef,
+ AnalysisRegistry analysisRegistry
+ ) {
super(blockFactory);
this.channel = channel;
this.aggregatorMode = aggregatorMode;
+ this.categorizeDef = categorizeDef;
this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
- 0.70f
+ categorizeDef.similarityThreshold() / 100.0f
);
if (aggregatorMode.isInputPartial() == false) {
- CategorizationAnalyzer analyzer;
+ CategorizationAnalyzer categorizationAnalyzer;
try {
Objects.requireNonNull(analysisRegistry);
- analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
- } catch (Exception e) {
+ CategorizationAnalyzerConfig config = categorizeDef.analyzer() == null
+ ? DEFAULT_ANALYZER_CONFIG
+ : new CategorizationAnalyzerConfig.Builder().setAnalyzer(categorizeDef.analyzer()).build();
+ categorizationAnalyzer = new CategorizationAnalyzer(analysisRegistry, config);
+ } catch (IOException e) {
categorizer.close();
throw new RuntimeException(e);
}
- this.evaluator = new CategorizeEvaluator(analyzer);
+ this.evaluator = new CategorizeEvaluator(categorizationAnalyzer);
} else {
this.evaluator = null;
}
@@ -114,7 +124,7 @@ public IntVector nonEmpty() {
@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
- return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
+ return new Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
}
@Override
@@ -222,7 +232,7 @@ private Block buildFinalBlock() {
try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
result.appendNull();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
- scratch.copyChars(category.getRegex());
+ scratch.copyChars(getKeyString(category));
result.appendBytesRef(scratch.get());
scratch.clear();
}
@@ -232,7 +242,7 @@ private Block buildFinalBlock() {
try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
- scratch.copyChars(category.getRegex());
+ scratch.copyChars(getKeyString(category));
result.appendBytesRef(scratch.get());
scratch.clear();
}
@@ -240,6 +250,13 @@ private Block buildFinalBlock() {
}
}
+ private String getKeyString(SerializableTokenListCategory category) {
+ return switch (categorizeDef.outputFormat()) {
+ case REGEX -> category.getRegex();
+ case TOKENS -> category.getKeyTokensString();
+ };
+ }
+
/**
* Similar implementation to an Evaluator.
*/
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java
index 20874cb10ceb8..bb5f0dee8ca2d 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java
@@ -56,6 +56,8 @@ public class CategorizePackedValuesBlockHash extends BlockHash {
int emitBatchSize
) {
super(blockFactory);
+ assert specs.get(0).categorizeDef() != null;
+
this.specs = specs;
this.aggregatorMode = aggregatorMode;
blocks = new Block[specs.size()];
@@ -68,7 +70,13 @@ public class CategorizePackedValuesBlockHash extends BlockHash {
boolean success = false;
try {
- categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
+ categorizeBlockHash = new CategorizeBlockHash(
+ blockFactory,
+ specs.get(0).channel(),
+ aggregatorMode,
+ specs.get(0).categorizeDef(),
+ analysisRegistry
+ );
packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
success = true;
} finally {
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
index 842952f9ef8bd..9ce086307acee 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
@@ -76,7 +76,13 @@ private void initAnalysisRegistry() throws IOException {
).getAnalysisRegistry();
}
+ private BlockHash.CategorizeDef getCategorizeDef() {
+ return new BlockHash.CategorizeDef(null, randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), 70);
+ }
+
public void testCategorizeRaw() {
+ BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
+
final Page page;
boolean withNull = randomBoolean();
final int positions = 7 + (withNull ? 1 : 0);
@@ -98,7 +104,7 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}
- try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
+ try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -137,7 +143,10 @@ public void close() {
}
});
- assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
+ switch (categorizeDef.outputFormat()) {
+ case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
+ case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected");
+ }
}
} finally {
page.releaseBlocks();
@@ -145,6 +154,8 @@ public void close() {
}
public void testCategorizeRawMultivalue() {
+ BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
+
final Page page;
boolean withNull = randomBoolean();
final int positions = 3 + (withNull ? 1 : 0);
@@ -170,7 +181,7 @@ public void testCategorizeRawMultivalue() {
page = new Page(builder.build());
}
- try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
+ try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -216,7 +227,10 @@ public void close() {
}
});
- assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
+ switch (categorizeDef.outputFormat()) {
+ case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
+ case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected");
+ }
}
} finally {
page.releaseBlocks();
@@ -224,6 +238,8 @@ public void close() {
}
public void testCategorizeIntermediate() {
+ BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
+
Page page1;
boolean withNull = randomBoolean();
int positions1 = 7 + (withNull ? 1 : 0);
@@ -259,8 +275,8 @@ public void testCategorizeIntermediate() {
// Fill intermediatePages with the intermediate state from the raw hashes
try (
- BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
- BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
+ BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry);
+ BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry);
) {
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -335,7 +351,7 @@ public void close() {
page2.releaseBlocks();
}
- try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
+ try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, categorizeDef, null)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
private void addBlock(int positionOffset, IntBlock groupIds) {
List values = IntStream.range(0, groupIds.getPositionCount())
@@ -403,14 +419,24 @@ public void close() {
}
});
- assertHashState(
- intermediateHash,
- withNull,
- ".*?Connected.+?to.*?",
- ".*?Connection.+?error.*?",
- ".*?Disconnected.*?",
- ".*?System.+?shutdown.*?"
- );
+ switch (categorizeDef.outputFormat()) {
+ case REGEX -> assertHashState(
+ intermediateHash,
+ withNull,
+ ".*?Connected.+?to.*?",
+ ".*?Connection.+?error.*?",
+ ".*?Disconnected.*?",
+ ".*?System.+?shutdown.*?"
+ );
+ case TOKENS -> assertHashState(
+ intermediateHash,
+ withNull,
+ "Connected to",
+ "Connection error",
+ "Disconnected",
+ "System shutdown"
+ );
+ }
}
} finally {
intermediatePage1.releaseBlocks();
@@ -419,6 +445,9 @@ public void close() {
}
public void testCategorize_withDriver() {
+ BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
+ BlockHash.GroupSpec groupSpec = new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef);
+
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
@@ -477,7 +506,7 @@ public void testCategorize_withDriver() {
new LocalSourceOperator(input1),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
- List.of(makeGroupSpec()),
+ List.of(groupSpec),
AggregatorMode.INITIAL,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)),
@@ -496,7 +525,7 @@ public void testCategorize_withDriver() {
new LocalSourceOperator(input2),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
- List.of(makeGroupSpec()),
+ List.of(groupSpec),
AggregatorMode.INITIAL,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)),
@@ -517,7 +546,7 @@ public void testCategorize_withDriver() {
new CannedSourceOperator(intermediateOutput.iterator()),
List.of(
new HashAggregationOperator.HashAggregationOperatorFactory(
- List.of(makeGroupSpec()),
+ List.of(groupSpec),
AggregatorMode.FINAL,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(1, 2)),
@@ -544,23 +573,36 @@ public void testCategorize_withDriver() {
sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
}
+ List<String> keys = switch (categorizeDef.outputFormat()) {
+ case REGEX -> List.of(
+ ".*?aaazz.*?",
+ ".*?bbbzz.*?",
+ ".*?ccczz.*?",
+ ".*?dddzz.*?",
+ ".*?eeezz.*?",
+ ".*?words.+?words.+?words.+?goodbye.*?",
+ ".*?words.+?words.+?words.+?hello.*?"
+ );
+ case TOKENS -> List.of("aaazz", "bbbzz", "ccczz", "dddzz", "eeezz", "words words words goodbye", "words words words hello");
+ };
+
assertThat(
sums,
equalTo(
Map.of(
- ".*?aaazz.*?",
+ keys.get(0),
1L,
- ".*?bbbzz.*?",
+ keys.get(1),
2L,
- ".*?ccczz.*?",
+ keys.get(2),
33L,
- ".*?dddzz.*?",
+ keys.get(3),
44L,
- ".*?eeezz.*?",
+ keys.get(4),
5L,
- ".*?words.+?words.+?words.+?goodbye.*?",
+ keys.get(5),
8888L,
- ".*?words.+?words.+?words.+?hello.*?",
+ keys.get(6),
999L
)
)
@@ -569,19 +611,19 @@ public void testCategorize_withDriver() {
maxs,
equalTo(
Map.of(
- ".*?aaazz.*?",
+ keys.get(0),
1L,
- ".*?bbbzz.*?",
+ keys.get(1),
2L,
- ".*?ccczz.*?",
+ keys.get(2),
30L,
- ".*?dddzz.*?",
+ keys.get(3),
40L,
- ".*?eeezz.*?",
+ keys.get(4),
5L,
- ".*?words.+?words.+?words.+?goodbye.*?",
+ keys.get(5),
8000L,
- ".*?words.+?words.+?words.+?hello.*?",
+ keys.get(6),
900L
)
)
@@ -589,10 +631,6 @@ public void testCategorize_withDriver() {
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
}
- private BlockHash.GroupSpec makeGroupSpec() {
- return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
- }
-
private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
// Check the keys
Block[] blocks = null;
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java
index 734b0660d24a3..d0eb89eafd841 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java
@@ -74,10 +74,15 @@ public void testCategorize_withDriver() {
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
boolean withNull = randomBoolean();
boolean withMultivalues = randomBoolean();
+ BlockHash.CategorizeDef categorizeDef = new BlockHash.CategorizeDef(
+ null,
+ randomFrom(BlockHash.CategorizeDef.OutputFormat.values()),
+ 70
+ );
List<BlockHash.GroupSpec> groupSpecs = List.of(
- new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true),
- new BlockHash.GroupSpec(1, ElementType.INT, false)
+ new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef),
+ new BlockHash.GroupSpec(1, ElementType.INT, null)
);
LocalSourceOperator.BlockSupplier input1 = () -> {
@@ -218,8 +223,12 @@ public void testCategorize_withDriver() {
}
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
+ List<String> keys = switch (categorizeDef.outputFormat()) {
+ case REGEX -> List.of(".*?connected.+?to.*?", ".*?connection.+?error.*?", ".*?disconnected.*?");
+ case TOKENS -> List.of("connected to", "connection error", "disconnected");
+ };
Map<String, Map<Integer, Set<String>>> expectedResult = Map.of(
- ".*?connected.+?to.*?",
+ keys.get(0),
Map.of(
7,
Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
@@ -228,9 +237,9 @@ public void testCategorize_withDriver() {
111,
Set.of("connected to 2.1.1")
),
- ".*?connection.+?error.*?",
+ keys.get(1),
Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
- ".*?disconnected.*?",
+ keys.get(2),
Map.of(7, Set.of("disconnected"))
);
if (withNull) {
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java
index f96b9d26f075c..0ebfa7e72b805 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java
@@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer callback, Block[]..
private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
for (int c = 0; c < values.length; c++) {
- specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c)));
+ specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c)));
}
assert forcePackedHash == false : "Packed TopN hash not implemented yet";
/*return forcePackedHash
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
index 106b9613d7bb2..0e9c0e33d22cd 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
@@ -113,7 +113,7 @@ public void testTopNNullsLast() {
try (
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
- List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, false, 3))),
+ List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))),
mode,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
@@ -190,7 +190,7 @@ public void testTopNNullsFirst() {
try (
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
- List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))),
+ List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3))),
mode,
List.of(
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
index 7168ca3dc398f..be46e68a8b08a 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
@@ -397,7 +397,7 @@ FROM sample_data
;
COUNT():long | SUM(event_duration):long | category:keyword
- 7 | 23231327 | null
+ 7 | 23231327 | null
;
on null row
@@ -800,3 +800,82 @@ COUNT():long | VALUES(str):keyword | category:keyword | str:keyword
1 | [a, b, c] | null | b
1 | [a, b, c] | null | c
;
+
+with option output_format regex
+required_capability: categorize_options
+
+FROM sample_data
+ | STATS count=COUNT()
+ BY category=CATEGORIZE(message, {"output_format": "regex"})
+ | SORT count DESC, category
+;
+
+count:long | category:keyword
+ 3 | .*?Connected.+?to.*?
+ 3 | .*?Connection.+?error.*?
+ 1 | .*?Disconnected.*?
+;
+
+with option output_format tokens
+required_capability: categorize_options
+
+FROM sample_data
+ | STATS count=COUNT()
+ BY category=CATEGORIZE(message, {"output_format": "tokens"})
+ | SORT count DESC, category
+;
+
+count:long | category:keyword
+ 3 | Connected to
+ 3 | Connection error
+ 1 | Disconnected
+;
+
+with option similarity_threshold
+required_capability: categorize_options
+
+FROM sample_data
+ | STATS count=COUNT()
+ BY category=CATEGORIZE(message, {"similarity_threshold": 99})
+ | SORT count DESC, category
+;
+
+count:long | category:keyword
+3 | .*?Connection.+?error.*?
+1 | .*?Connected.+?to.+?10\.1\.0\.1.*?
+1 | .*?Connected.+?to.+?10\.1\.0\.2.*?
+1 | .*?Connected.+?to.+?10\.1\.0\.3.*?
+1 | .*?Disconnected.*?
+;
+
+with option analyzer
+required_capability: categorize_options
+
+FROM sample_data
+ | STATS count=COUNT()
+ BY category=CATEGORIZE(message, {"analyzer": "stop"})
+ | SORT count DESC, category
+;
+
+count:long | category:keyword
+3 | .*?connected.*?
+3 | .*?connection.+?error.*?
+1 | .*?disconnected.*?
+;
+
+with all options
+required_capability: categorize_options
+
+FROM sample_data
+ | STATS count=COUNT()
+ BY category=CATEGORIZE(message, {"analyzer": "whitespace", "similarity_threshold": 100, "output_format": "tokens"})
+ | SORT count DESC, category
+;
+
+count:long | category:keyword
+3 | Connection error
+1 | Connected to 10.1.0.1
+1 | Connected to 10.1.0.2
+1 | Connected to 10.1.0.3
+1 | Disconnected
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index ac75811602bd6..f4777821616f1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1248,10 +1248,12 @@ public enum Cap {
* FUSE command
*/
FUSE(Build.current().isSnapshot()),
+
/**
* Support improved behavior for LIKE operator when used with index fields.
*/
LIKE_ON_INDEX_FIELDS,
+
/**
* Support avg with aggregate metric doubles
*/
@@ -1268,10 +1270,15 @@ public enum Cap {
*/
FAIL_IF_ALL_SHARDS_FAIL(Build.current().isSnapshot()),
- /*
+ /**
* Cosine vector similarity function
*/
- COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
+ COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),
+
+ /**
+ * Support for the options field of CATEGORIZE.
+ */
+ CATEGORIZE_OPTIONS;
private final boolean enabled;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java
new file mode 100644
index 0000000000000..891d8f1e6c264
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
+import org.elasticsearch.xpack.esql.core.expression.EntryExpression;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MapExpression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.core.type.DataTypeConverter;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
+
+public class Options {
+
+ public static Expression.TypeResolution resolve(
+ Expression options,
+ Source source,
+ TypeResolutions.ParamOrdinal paramOrdinal,
+ Map<String, DataType> allowedOptions
+ ) {
+ return resolve(options, source, paramOrdinal, allowedOptions, null);
+ }
+
+ public static Expression.TypeResolution resolve(
+ Expression options,
+ Source source,
+ TypeResolutions.ParamOrdinal paramOrdinal,
+ Map<String, DataType> allowedOptions,
+ Consumer
*/
@SupportsObservabilityTier(tier = COMPLETE)
-public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements LicenseAware {
+public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements OptionalArgument, LicenseAware {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"Categorize",
Categorize::new
);
+ private static final String ANALYZER = "analyzer";
+ private static final String OUTPUT_FORMAT = "output_format";
+ private static final String SIMILARITY_THRESHOLD = "similarity_threshold";
+
+ private static final Map<String, DataType> ALLOWED_OPTIONS = new TreeMap<>(
+ Map.ofEntries(entry(ANALYZER, KEYWORD), entry(OUTPUT_FORMAT, KEYWORD), entry(SIMILARITY_THRESHOLD, INTEGER))
+ );
+
private final Expression field;
+ private final Expression options;
@FunctionInfo(
returnType = "keyword",
@@ -70,21 +97,56 @@ public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction
)
public Categorize(
Source source,
- @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
-
+ @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field,
+ @MapParam(
+ name = "options",
+ description = "(Optional) Categorize additional options as <<esql-function-named-params,function named parameters>>.",
+ params = {
+ @MapParam.MapParamEntry(
+ name = ANALYZER,
+ type = "keyword",
+ valueHint = { "standard" },
+ description = "Analyzer used to convert the field into tokens for text categorization."
+ ),
+ @MapParam.MapParamEntry(
+ name = OUTPUT_FORMAT,
+ type = "keyword",
+ valueHint = { "regex", "tokens" },
+ description = "The output format of the categories. Defaults to regex."
+ ),
+ @MapParam.MapParamEntry(
+ name = SIMILARITY_THRESHOLD,
+ type = "integer",
+ valueHint = { "70" },
+ description = "The minimum percentage of token weight that must match for text to be added to the category bucket. "
+ + "Must be between 1 and 100. The larger the value the narrower the categories. "
+ + "Larger values will increase memory usage and create narrower categories. Defaults to 70."
+ ), },
+ optional = true
+ ) Expression options
) {
- super(source, List.of(field));
+ super(source, options == null ? List.of(field) : List.of(field, options));
this.field = field;
+ this.options = options;
}
private Categorize(StreamInput in) throws IOException {
- this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(Expression.class),
+ in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)
+ ? in.readOptionalNamedWriteable(Expression.class)
+ : null
+ );
}
@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field);
+ if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)) {
+ out.writeOptionalNamedWriteable(options);
+ }
}
@Override
@@ -107,7 +169,48 @@ public Nullability nullable() {
@Override
protected TypeResolution resolveType() {
- return isString(field(), sourceText(), DEFAULT);
+ return isString(field(), sourceText(), DEFAULT).and(
+ Options.resolve(options, source(), SECOND, ALLOWED_OPTIONS, this::verifyOptions)
+ );
+ }
+
+ private void verifyOptions(Map<String, Object> optionsMap) {
+ if (options == null) {
+ return;
+ }
+ Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD);
+ if (similarityThreshold != null) {
+ if (similarityThreshold <= 0 || similarityThreshold > 100) {
+ throw new InvalidArgumentException(
+ format("invalid similarity threshold [{}], expecting a number between 1 and 100, inclusive", similarityThreshold)
+ );
+ }
+ }
+ String outputFormat = (String) optionsMap.get(OUTPUT_FORMAT);
+ if (outputFormat != null) {
+ try {
+ OutputFormat.valueOf(outputFormat.toUpperCase(Locale.ROOT));
+ } catch (IllegalArgumentException e) {
+ throw new InvalidArgumentException(
+ format(null, "invalid output format [{}], expecting one of [REGEX, TOKENS]", outputFormat)
+ );
+ }
+ }
+ }
+
+ public CategorizeDef categorizeDef() {
+ Map<String, Object> optionsMap = new HashMap<>();
+ if (options != null) {
+ Options.populateMap((MapExpression) options, optionsMap, source(), SECOND, ALLOWED_OPTIONS);
+ }
+ Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD);
+ String outputFormatString = (String) optionsMap.get(OUTPUT_FORMAT);
+ OutputFormat outputFormat = outputFormatString == null ? null : OutputFormat.valueOf(outputFormatString.toUpperCase(Locale.ROOT));
+ return new CategorizeDef(
+ (String) optionsMap.get("analyzer"),
+ outputFormat == null ? REGEX : outputFormat,
+ similarityThreshold == null ? 70 : similarityThreshold
+ );
}
@Override
@@ -117,12 +220,12 @@ public DataType dataType() {
@Override
public Categorize replaceChildren(List<Expression> newChildren) {
- return new Categorize(source(), newChildren.get(0));
+ return new Categorize(source(), newChildren.get(0), newChildren.size() > 1 ? newChildren.get(1) : null);
}
@Override
protected NodeInfo<? extends Expression> info() {
- return NodeInfo.create(this, Categorize::new, field);
+ return NodeInfo.create(this, Categorize::new, field, options);
}
public Expression field() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
index 61528521c3749..cab5ec862d7f5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
@@ -30,6 +30,7 @@
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.MapParam;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
+import org.elasticsearch.xpack.esql.expression.function.Options;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
@@ -53,10 +54,10 @@
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
-import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
@@ -198,7 +199,7 @@ public DataType dataType() {
@Override
protected TypeResolution resolveParams() {
- return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions());
+ return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS));
}
private TypeResolution resolveField() {
@@ -221,37 +222,6 @@ private TypeResolution resolveK() {
.and(isNotNull(k(), sourceText(), THIRD));
}
- private TypeResolution resolveOptions() {
- if (options() != null) {
- TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
- if (resolution.unresolved()) {
- return resolution;
- }
- // MapExpression does not have a DataType associated with it
- resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
- if (resolution.unresolved()) {
- return resolution;
- }
-
- try {
- knnQueryOptions();
- } catch (InvalidArgumentException e) {
- return new TypeResolution(e.getMessage());
- }
- }
- return TypeResolution.TYPE_RESOLVED;
- }
-
- private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
- if (options() == null) {
- return Map.of();
- }
-
- Map<String, Object> matchOptions = new HashMap<>();
- populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
- return matchOptions;
- }
-
@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
@@ -307,7 +277,7 @@ public Expression withFilters(List<Expression> filterExpressions) {
private Map<String, Object> queryOptions() throws InvalidArgumentException {
Map<String, Object> options = new HashMap<>();
if (options() != null) {
- populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
+ Options.populateMap((MapExpression) options(), options, source(), FOURTH, ALLOWED_OPTIONS);
}
return options;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
index dd7ee26aa84bd..8fe9ccc18c006 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
@@ -10,6 +10,7 @@
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
@@ -137,13 +138,13 @@ private static Expression transformNonEvaluatableGroupingFunction(
List<Expression> newChildren = new ArrayList<>(gf.children().size());
for (Expression ex : gf.children()) {
- if (ex instanceof Attribute == false) { // TODO: foldables shouldn't require eval'ing either
+ if (ex instanceof Attribute || ex instanceof MapExpression) {
+ newChildren.add(ex);
+ } else { // TODO: foldables shouldn't require eval'ing either
var alias = new Alias(ex.source(), syntheticName(ex, gf, counter++), ex, null, true);
evals.add(alias);
newChildren.add(alias.toAttribute());
childrenChanged = true;
- } else {
- newChildren.add(ex);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index a5d19fcc3fb14..e45fe2b0e81d8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -343,8 +343,12 @@ BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
}
-
- return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null);
+ return new BlockHash.GroupSpec(
+ channel,
+ elementType(),
+ Alias.unwrap(expression) instanceof Categorize categorize ? categorize.categorizeDef() : null,
+ null
+ );
}
ElementType elementType() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index 4951479514b72..5d4260eb4ee66 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -1972,6 +1972,57 @@ public void testCategorizeWithFilteredAggregations() {
);
}
+ public void testCategorizeInvalidOptionsField() {
+ assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled());
+
+ assertEquals(
+ "1:31: second argument of [CATEGORIZE(last_name, first_name)] must be a map expression, received [first_name]",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, first_name)")
+ );
+ assertEquals(
+ "1:31: Invalid option [blah] in [CATEGORIZE(last_name, { \"blah\": 42 })], "
+ + "expected one of [analyzer, output_format, similarity_threshold]",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"blah\": 42 })")
+ );
+ }
+
+ public void testCategorizeOptionOutputFormat() {
+ assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled());
+
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"regex\" })");
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"REGEX\" })");
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"tokens\" })");
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"ToKeNs\" })");
+ assertEquals(
+ "1:31: invalid output format [blah], expecting one of [REGEX, TOKENS]",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"blah\" })")
+ );
+ assertEquals(
+ "1:31: invalid output format [42], expecting one of [REGEX, TOKENS]",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": 42 })")
+ );
+ }
+
+ public void testCategorizeOptionSimilarityThreshold() {
+ assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled());
+
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 1 })");
+ query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 100 })");
+ assertEquals(
+ "1:31: invalid similarity threshold [0], expecting a number between 1 and 100, inclusive",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 0 })")
+ );
+ assertEquals(
+ "1:31: invalid similarity threshold [101], expecting a number between 1 and 100, inclusive",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 101 })")
+ );
+ assertEquals(
+ "1:31: Invalid option [similarity_threshold] in [CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })], "
+ + "cannot cast [blah] to [integer]",
+ error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })")
+ );
+ }
+
public void testChangePoint() {
assumeTrue("change_point must be enabled", EsqlCapabilities.Cap.CHANGE_POINT.isEnabled());
var airports = AnalyzerTestUtils.analyzer(loadMapping("mapping-airports.json", "airports"));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java
index f674f9b2c3d72..97d5b8e3ece96 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java
@@ -27,7 +27,7 @@ protected List<TestCaseSupplier> cases() {
@Override
protected Expression build(Source source, List<Expression> args) {
- return new Categorize(source, args.get(0));
+ return new Categorize(source, args.get(0), args.size() > 1 ? args.get(1) : null);
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java
index f69bb7eb3e7bb..296d624ee1777 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java
@@ -61,7 +61,7 @@ public static Iterable<Object[]> parameters() {