diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index d144d7601349d..d5fe1b4a697e0 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -191,7 +191,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S new BlockHash.GroupSpec(2, ElementType.BYTES_REF) ); case TOP_N_LONGS -> List.of( - new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT)) + new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT)) ); default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]"); }; diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md new file mode 100644 index 0000000000000..acd2064002b44 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md @@ -0,0 +1,13 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported function named parameters** + +`output_format` +: (keyword) The output format of the categories. Defaults to regex. + +`similarity_threshold` +: (integer) The minimum percentage of token weight that must match for text to be added to the category bucket. Must be between 1 and 100. The larger the value the narrower the categories. Larger values will increase memory usage and create narrower categories. Defaults to 70. + +`analyzer` +: (keyword) Analyzer used to convert the field into tokens for text categorization. 
+ diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md index ca23c1e2efc23..2e331187665f4 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md @@ -19,5 +19,8 @@ :::{include} ../types/categorize.md ::: +:::{include} ../functionNamedParams/categorize.md +::: + :::{include} ../examples/categorize.md ::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md index 8733908754570..c013b67375a3d 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md @@ -5,3 +5,6 @@ `field` : Expression to categorize +`options` +: (Optional) Categorize additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params). 
+ diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md index 6043fbe719ff8..8ebe22b61286c 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md @@ -2,8 +2,8 @@ **Supported types** -| field | result | -| --- | --- | -| keyword | keyword | -| text | keyword | +| field | options | result | +| --- | --- | --- | +| keyword | | keyword | +| text | | keyword | diff --git a/docs/reference/query-languages/esql/images/functions/categorize.svg b/docs/reference/query-languages/esql/images/functions/categorize.svg index bbb2bda7c480b..7629b9bb978ba 100644 --- a/docs/reference/query-languages/esql/images/functions/categorize.svg +++ b/docs/reference/query-languages/esql/images/functions/categorize.svg @@ -1 +1 @@ -CATEGORIZE(field) \ No newline at end of file +CATEGORIZE(field,options) \ No newline at end of file diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index ae0ccecf15ed7..99c255acf7268 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -340,6 +340,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00); public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00); public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00); + public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 1cae296f09c02..63f4d9c96bcd0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -128,16 +128,26 @@ public abstract class BlockHash implements Releasable, SeenGroupIds { public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {} /** - * @param isCategorize Whether this group is a CATEGORIZE() or not. - * May be changed in the future when more stateful grouping functions are added. + * Configuration for a BlockHash group spec that is doing text categorization. */ - public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) { + public record CategorizeDef(String analyzer, OutputFormat outputFormat, int similarityThreshold) { + public enum OutputFormat { + REGEX, + TOKENS + } + } + + public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) { public GroupSpec(int channel, ElementType elementType) { - this(channel, elementType, false, null); + this(channel, elementType, null, null); + } + + public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) { + this(channel, elementType, categorizeDef, null); } - public GroupSpec(int channel, ElementType elementType, boolean isCategorize) { - this(channel, elementType, isCategorize, null); + public boolean isCategorize() { + return categorizeDef != null; } } @@ -207,7 +217,13 @@ public static BlockHash buildCategorizeBlockHash( int emitBatchSize ) { if (groups.size() == 1) { - return new CategorizeBlockHash(blockFactory, groups.get(0).channel, 
aggregatorMode, analysisRegistry); + return new CategorizeBlockHash( + blockFactory, + groups.get(0).channel, + aggregatorMode, + groups.get(0).categorizeDef, + analysisRegistry + ); } else { assert groups.get(0).isCategorize(); assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java index 5e716d8c9d5ff..fcc1a7f3d271e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java @@ -18,7 +18,6 @@ import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -47,12 +46,13 @@ */ public class CategorizeBlockHash extends BlockHash { - private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig + private static final CategorizationAnalyzerConfig DEFAULT_ANALYZER_CONFIG = CategorizationAnalyzerConfig .buildStandardEsqlCategorizationAnalyzer(); private static final int NULL_ORD = 0; private final int channel; private final AggregatorMode aggregatorMode; + private final CategorizeDef categorizeDef; private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; private final CategorizeEvaluator evaluator; @@ -64,28 +64,38 @@ public class CategorizeBlockHash extends BlockHash { */ private boolean seenNull = false; - 
CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) { + CategorizeBlockHash( + BlockFactory blockFactory, + int channel, + AggregatorMode aggregatorMode, + CategorizeDef categorizeDef, + AnalysisRegistry analysisRegistry + ) { super(blockFactory); this.channel = channel; this.aggregatorMode = aggregatorMode; + this.categorizeDef = categorizeDef; this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer( new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())), CategorizationPartOfSpeechDictionary.getInstance(), - 0.70f + categorizeDef.similarityThreshold() / 100.0f ); if (aggregatorMode.isInputPartial() == false) { - CategorizationAnalyzer analyzer; + CategorizationAnalyzer categorizationAnalyzer; try { Objects.requireNonNull(analysisRegistry); - analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG); - } catch (Exception e) { + CategorizationAnalyzerConfig config = categorizeDef.analyzer() == null + ? DEFAULT_ANALYZER_CONFIG + : new CategorizationAnalyzerConfig.Builder().setAnalyzer(categorizeDef.analyzer()).build(); + categorizationAnalyzer = new CategorizationAnalyzer(analysisRegistry, config); + } catch (IOException e) { categorizer.close(); throw new RuntimeException(e); } - this.evaluator = new CategorizeEvaluator(analyzer); + this.evaluator = new CategorizeEvaluator(categorizationAnalyzer); } else { this.evaluator = null; } @@ -114,7 +124,7 @@ public IntVector nonEmpty() { @Override public BitArray seenGroupIds(BigArrays bigArrays) { - return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); + return new Range(seenNull ? 
0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); } @Override @@ -222,7 +232,7 @@ private Block buildFinalBlock() { try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) { result.appendNull(); for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); + scratch.copyChars(getKeyString(category)); result.appendBytesRef(scratch.get()); scratch.clear(); } @@ -232,7 +242,7 @@ private Block buildFinalBlock() { try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); + scratch.copyChars(getKeyString(category)); result.appendBytesRef(scratch.get()); scratch.clear(); } @@ -240,6 +250,13 @@ private Block buildFinalBlock() { } } + private String getKeyString(SerializableTokenListCategory category) { + return switch (categorizeDef.outputFormat()) { + case REGEX -> category.getRegex(); + case TOKENS -> category.getKeyTokensString(); + }; + } + /** * Similar implementation to an Evaluator. 
*/ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java index 20874cb10ceb8..bb5f0dee8ca2d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java @@ -56,6 +56,8 @@ public class CategorizePackedValuesBlockHash extends BlockHash { int emitBatchSize ) { super(blockFactory); + assert specs.get(0).categorizeDef() != null; + this.specs = specs; this.aggregatorMode = aggregatorMode; blocks = new Block[specs.size()]; @@ -68,7 +70,13 @@ public class CategorizePackedValuesBlockHash extends BlockHash { boolean success = false; try { - categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry); + categorizeBlockHash = new CategorizeBlockHash( + blockFactory, + specs.get(0).channel(), + aggregatorMode, + specs.get(0).categorizeDef(), + analysisRegistry + ); packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize); success = true; } finally { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 842952f9ef8bd..9ce086307acee 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -76,7 +76,13 @@ private void initAnalysisRegistry() throws IOException { 
).getAnalysisRegistry(); } + private BlockHash.CategorizeDef getCategorizeDef() { + return new BlockHash.CategorizeDef(null, randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), 70); + } + public void testCategorizeRaw() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + final Page page; boolean withNull = randomBoolean(); final int positions = 7 + (withNull ? 1 : 0); @@ -98,7 +104,7 @@ public void testCategorizeRaw() { page = new Page(builder.build()); } - try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -137,7 +143,10 @@ public void close() { } }); - assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected"); + } } } finally { page.releaseBlocks(); @@ -145,6 +154,8 @@ public void close() { } public void testCategorizeRawMultivalue() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + final Page page; boolean withNull = randomBoolean(); final int positions = 3 + (withNull ? 
1 : 0); @@ -170,7 +181,7 @@ public void testCategorizeRawMultivalue() { page = new Page(builder.build()); } - try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -216,7 +227,10 @@ public void close() { } }); - assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected"); + } } } finally { page.releaseBlocks(); @@ -224,6 +238,8 @@ public void close() { } public void testCategorizeIntermediate() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + Page page1; boolean withNull = randomBoolean(); int positions1 = 7 + (withNull ? 
1 : 0); @@ -259,8 +275,8 @@ public void testCategorizeIntermediate() { // Fill intermediatePages with the intermediate state from the raw hashes try ( - BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); - BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); + BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry); + BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry); ) { rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -335,7 +351,7 @@ public void close() { page2.releaseBlocks(); } - try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) { + try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, categorizeDef, null)) { intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { List values = IntStream.range(0, groupIds.getPositionCount()) @@ -403,14 +419,24 @@ public void close() { } }); - assertHashState( - intermediateHash, - withNull, - ".*?Connected.+?to.*?", - ".*?Connection.+?error.*?", - ".*?Disconnected.*?", - ".*?System.+?shutdown.*?" - ); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState( + intermediateHash, + withNull, + ".*?Connected.+?to.*?", + ".*?Connection.+?error.*?", + ".*?Disconnected.*?", + ".*?System.+?shutdown.*?" 
+ ); + case TOKENS -> assertHashState( + intermediateHash, + withNull, + "Connected to", + "Connection error", + "Disconnected", + "System shutdown" + ); + } } } finally { intermediatePage1.releaseBlocks(); @@ -419,6 +445,9 @@ public void close() { } public void testCategorize_withDriver() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + BlockHash.GroupSpec groupSpec = new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef); + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking(); CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); @@ -477,7 +506,7 @@ public void testCategorize_withDriver() { new LocalSourceOperator(input1), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.INITIAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)), @@ -496,7 +525,7 @@ public void testCategorize_withDriver() { new LocalSourceOperator(input2), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.INITIAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)), @@ -517,7 +546,7 @@ public void testCategorize_withDriver() { new CannedSourceOperator(intermediateOutput.iterator()), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.FINAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(1, 2)), @@ -544,23 +573,36 @@ public void testCategorize_withDriver() { sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), 
outputSums.getLong(i)); maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i)); } + List keys = switch (categorizeDef.outputFormat()) { + case REGEX -> List.of( + ".*?aaazz.*?", + ".*?bbbzz.*?", + ".*?ccczz.*?", + ".*?dddzz.*?", + ".*?eeezz.*?", + ".*?words.+?words.+?words.+?goodbye.*?", + ".*?words.+?words.+?words.+?hello.*?" + ); + case TOKENS -> List.of("aaazz", "bbbzz", "ccczz", "dddzz", "eeezz", "words words words goodbye", "words words words hello"); + }; + assertThat( sums, equalTo( Map.of( - ".*?aaazz.*?", + keys.get(0), 1L, - ".*?bbbzz.*?", + keys.get(1), 2L, - ".*?ccczz.*?", + keys.get(2), 33L, - ".*?dddzz.*?", + keys.get(3), 44L, - ".*?eeezz.*?", + keys.get(4), 5L, - ".*?words.+?words.+?words.+?goodbye.*?", + keys.get(5), 8888L, - ".*?words.+?words.+?words.+?hello.*?", + keys.get(6), 999L ) ) @@ -569,19 +611,19 @@ public void testCategorize_withDriver() { maxs, equalTo( Map.of( - ".*?aaazz.*?", + keys.get(0), 1L, - ".*?bbbzz.*?", + keys.get(1), 2L, - ".*?ccczz.*?", + keys.get(2), 30L, - ".*?dddzz.*?", + keys.get(3), 40L, - ".*?eeezz.*?", + keys.get(4), 5L, - ".*?words.+?words.+?words.+?goodbye.*?", + keys.get(5), 8000L, - ".*?words.+?words.+?words.+?hello.*?", + keys.get(6), 900L ) ) @@ -589,10 +631,6 @@ public void testCategorize_withDriver() { Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); } - private BlockHash.GroupSpec makeGroupSpec() { - return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true); - } - private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... 
expectedKeys) { // Check the keys Block[] blocks = null; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java index 734b0660d24a3..d0eb89eafd841 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java @@ -74,10 +74,15 @@ public void testCategorize_withDriver() { DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); boolean withNull = randomBoolean(); boolean withMultivalues = randomBoolean(); + BlockHash.CategorizeDef categorizeDef = new BlockHash.CategorizeDef( + null, + randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), + 70 + ); List groupSpecs = List.of( - new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true), - new BlockHash.GroupSpec(1, ElementType.INT, false) + new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef), + new BlockHash.GroupSpec(1, ElementType.INT, null) ); LocalSourceOperator.BlockSupplier input1 = () -> { @@ -218,8 +223,12 @@ public void testCategorize_withDriver() { } Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); + List keys = switch (categorizeDef.outputFormat()) { + case REGEX -> List.of(".*?connected.+?to.*?", ".*?connection.+?error.*?", ".*?disconnected.*?"); + case TOKENS -> List.of("connected to", "connection error", "disconnected"); + }; Map>> expectedResult = Map.of( - ".*?connected.+?to.*?", + keys.get(0), Map.of( 7, Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"), @@ -228,9 +237,9 @@ public void testCategorize_withDriver() { 111, 
Set.of("connected to 2.1.1") ), - ".*?connection.+?error.*?", + keys.get(1), Map.of(7, Set.of("connection error"), 42, Set.of("connection error")), - ".*?disconnected.*?", + keys.get(2), Map.of(7, Set.of("disconnected")) ); if (withNull) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java index f96b9d26f075c..0ebfa7e72b805 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java @@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer callback, Block[].. private BlockHash buildBlockHash(int emitBatchSize, Block... values) { List specs = new ArrayList<>(values.length); for (int c = 0; c < values.length; c++) { - specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c))); + specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c))); } assert forcePackedHash == false : "Packed TopN hash not implemented yet"; /*return forcePackedHash diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index 106b9613d7bb2..0e9c0e33d22cd 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -113,7 +113,7 @@ public void testTopNNullsLast() { try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new 
BlockHash.TopNDef(0, ascOrder, false, 3))), + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), @@ -190,7 +190,7 @@ public void testTopNNullsFirst() { try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))), + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3))), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index 7168ca3dc398f..be46e68a8b08a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -397,7 +397,7 @@ FROM sample_data ; COUNT():long | SUM(event_duration):long | category:keyword - 7 | 23231327 | null + 7 | 23231327 | null ; on null row @@ -800,3 +800,82 @@ COUNT():long | VALUES(str):keyword | category:keyword | str:keyword 1 | [a, b, c] | null | b 1 | [a, b, c] | null | c ; + +with option output_format regex +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"output_format": "regex"}) + | SORT count DESC, category +; + +count:long | category:keyword + 3 | .*?Connected.+?to.*? + 3 | .*?Connection.+?error.*? + 1 | .*?Disconnected.*? 
+; + +with option output_format tokens +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"output_format": "tokens"}) + | SORT count DESC, category +; + +count:long | category:keyword + 3 | Connected to + 3 | Connection error + 1 | Disconnected +; + +with option similarity_threshold +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"similarity_threshold": 99}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | .*?Connection.+?error.*? +1 | .*?Connected.+?to.+?10\.1\.0\.1.*? +1 | .*?Connected.+?to.+?10\.1\.0\.2.*? +1 | .*?Connected.+?to.+?10\.1\.0\.3.*? +1 | .*?Disconnected.*? +; + +with option analyzer +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"analyzer": "stop"}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | .*?connected.*? +3 | .*?connection.+?error.*? +1 | .*?disconnected.*? 
+; + +with all options +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"analyzer": "whitespace", "similarity_threshold": 100, "output_format": "tokens"}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | Connection error +1 | Connected to 10.1.0.1 +1 | Connected to 10.1.0.2 +1 | Connected to 10.1.0.3 +1 | Disconnected +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index ac75811602bd6..f4777821616f1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1248,10 +1248,12 @@ public enum Cap { * FUSE command */ FUSE(Build.current().isSnapshot()), + /** * Support improved behavior for LIKE operator when used with index fields. */ LIKE_ON_INDEX_FIELDS, + /** * Support avg with aggregate metric doubles */ @@ -1268,10 +1270,15 @@ public enum Cap { */ FAIL_IF_ALL_SHARDS_FAIL(Build.current().isSnapshot()), - /* + /** * Cosine vector similarity function */ - COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()); + COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()), + + /** + * Support for the options field of CATEGORIZE. + */ + CATEGORIZE_OPTIONS; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java new file mode 100644 index 0000000000000..891d8f1e6c264 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.EntryExpression; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; + +public class Options { + + public static Expression.TypeResolution resolve( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Map allowedOptions + ) { + return resolve(options, source, paramOrdinal, allowedOptions, null); + } + + public static Expression.TypeResolution resolve( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Map allowedOptions, + Consumer> verifyOptions + ) { + if (options != null) { + Expression.TypeResolution resolution = isNotNull(options, source.text(), paramOrdinal); + if (resolution.unresolved()) { + return resolution; + } + // MapExpression does not have a DataType associated with it + 
resolution = isMapExpression(options, source.text(), paramOrdinal); + if (resolution.unresolved()) { + return resolution; + } + try { + Map optionsMap = new HashMap<>(); + populateMap((MapExpression) options, optionsMap, source, paramOrdinal, allowedOptions); + if (verifyOptions != null) { + verifyOptions.accept(optionsMap); + } + } catch (InvalidArgumentException e) { + return new Expression.TypeResolution(e.getMessage()); + } + } + return Expression.TypeResolution.TYPE_RESOLVED; + } + + public static void populateMap( + final MapExpression options, + final Map optionsMap, + final Source source, + final TypeResolutions.ParamOrdinal paramOrdinal, + final Map allowedOptions + ) throws InvalidArgumentException { + for (EntryExpression entry : options.entryExpressions()) { + Expression optionExpr = entry.key(); + Expression valueExpr = entry.value(); + Expression.TypeResolution resolution = isFoldable(optionExpr, source.text(), paramOrdinal).and( + isFoldable(valueExpr, source.text(), paramOrdinal) + ); + if (resolution.unresolved()) { + throw new InvalidArgumentException(resolution.message()); + } + Object optionExprLiteral = ((Literal) optionExpr).value(); + Object valueExprLiteral = ((Literal) valueExpr).value(); + String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); + String optionValue = valueExprLiteral instanceof BytesRef br ? 
br.utf8ToString() : valueExprLiteral.toString(); + // validate the optionExpr is supported + DataType dataType = allowedOptions.get(optionName); + if (dataType == null) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, source.text(), allowedOptions.keySet()) + ); + } + try { + optionsMap.put(optionName, DataTypeConverter.convert(optionValue, dataType)); + } catch (InvalidArgumentException e) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], {}", optionName, source.text(), e.getMessage()) + ); + } + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index c347340c25050..b5378db783f46 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.fulltext; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig; import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator; @@ -20,20 +19,15 @@ import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; -import org.elasticsearch.xpack.esql.core.InvalidArgumentException; -import org.elasticsearch.xpack.esql.core.expression.EntryExpression; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import 
org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; @@ -55,17 +49,12 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Predicate; -import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; @@ -409,66 +398,6 @@ public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { return new LuceneQueryScoreEvaluator.Factory(shardConfigs); } - protected static void populateOptionsMap( - final MapExpression 
options, - final Map optionsMap, - final TypeResolutions.ParamOrdinal paramOrdinal, - final String sourceText, - final Map allowedOptions - ) throws InvalidArgumentException { - for (EntryExpression entry : options.entryExpressions()) { - Expression optionExpr = entry.key(); - Expression valueExpr = entry.value(); - TypeResolution resolution = isFoldable(optionExpr, sourceText, paramOrdinal).and( - isFoldable(valueExpr, sourceText, paramOrdinal) - ); - if (resolution.unresolved()) { - throw new InvalidArgumentException(resolution.message()); - } - Object optionExprLiteral = ((Literal) optionExpr).value(); - Object valueExprLiteral = ((Literal) valueExpr).value(); - String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); - String optionValue = valueExprLiteral instanceof BytesRef br ? br.utf8ToString() : valueExprLiteral.toString(); - // validate the optionExpr is supported - DataType dataType = allowedOptions.get(optionName); - if (dataType == null) { - throw new InvalidArgumentException( - format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, sourceText, allowedOptions.keySet()) - ); - } - try { - optionsMap.put(optionName, DataTypeConverter.convert(optionValue, dataType)); - } catch (InvalidArgumentException e) { - throw new InvalidArgumentException(format(null, "Invalid option [{}] in [{}], {}", optionName, sourceText, e.getMessage())); - } - } - } - - protected TypeResolution resolveOptions(Expression options, TypeResolutions.ParamOrdinal paramOrdinal) { - if (options != null) { - TypeResolution resolution = isNotNull(options, sourceText(), paramOrdinal); - if (resolution.unresolved()) { - return resolution; - } - // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options, sourceText(), paramOrdinal); - if (resolution.unresolved()) { - return resolution; - } - - try { - resolvedOptions(); - } catch (InvalidArgumentException e) { - return new 
TypeResolution(e.getMessage()); - } - } - return TypeResolution.TYPE_RESOLVED; - } - - protected Map resolvedOptions() throws InvalidArgumentException { - return Map.of(); - } - // TODO: this should likely be replaced by calls to FieldAttribute#fieldName; the MultiTypeEsField case looks // wrong if `fieldAttribute` is a subfield, e.g. `parent.child` - multiTypeEsField#getName will just return `child`. public static String getNameFromFieldAttribute(FieldAttribute fieldAttribute) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java index 743263a878552..5c5a46fd2f759 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; @@ -298,7 +299,9 @@ public final void writeTo(StreamOutput out) throws IOException { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions(options(), THIRD)).and(checkParamCompatibility()); + return resolveField().and(resolveQuery()) + .and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS)) + .and(checkParamCompatibility()); } private TypeResolution resolveField() { @@ -342,11 +345,6 @@ private TypeResolution checkParamCompatibility() { return new 
TypeResolution(formatIncompatibleTypesMessage(fieldType, queryType, sourceText())); } - @Override - protected Map resolvedOptions() { - return matchQueryOptions(); - } - private Map matchQueryOptions() throws InvalidArgumentException { if (options() == null) { return Map.of(LENIENT_FIELD.getPreferredName(), true); @@ -356,7 +354,7 @@ private Map matchQueryOptions() throws InvalidArgumentException // Match is lenient by default to avoid failing on incompatible types matchOptions.put(LENIENT_FIELD.getPreferredName(), true); - populateOptionsMap((MapExpression) options(), matchOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchOptions, source(), SECOND, ALLOWED_OPTIONS); return matchOptions; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java index a41a9792f7943..4ed0e16ab5b4a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; @@ -187,7 +188,7 @@ public final void writeTo(StreamOutput out) throws IOException { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions(options(), THIRD)); + return 
resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS)); } private TypeResolution resolveField() { @@ -200,18 +201,13 @@ private TypeResolution resolveQuery() { ); } - @Override - protected Map resolvedOptions() throws InvalidArgumentException { - return matchPhraseQueryOptions(); - } - private Map matchPhraseQueryOptions() throws InvalidArgumentException { if (options() == null) { return Map.of(); } Map matchPhraseOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchPhraseOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchPhraseOptions, source(), SECOND, ALLOWED_OPTIONS); return matchPhraseOptions; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java index 1178178c432fc..3e9fed6be850d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; @@ -368,7 +369,7 @@ private Map getOptions() throws InvalidArgumentException { return options; } - Match.populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), OPTIONS); + Options.populateMap((MapExpression) options(), options, 
source(), THIRD, OPTIONS); return options; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java index a4c1b1f12fb56..7285f19fc5aa7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; @@ -321,18 +322,13 @@ private Map queryStringOptions() throws InvalidArgumentException } Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchOptions, source(), SECOND, ALLOWED_OPTIONS); return matchOptions; } - @Override - protected Map resolvedOptions() { - return queryStringOptions(); - } - @Override protected TypeResolution resolveParams() { - return resolveQuery().and(resolveOptions(options(), SECOND)); + return resolveQuery().and(Options.resolve(options(), source(), SECOND, ALLOWED_OPTIONS)); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 15b4621589457..75918091f9ecd 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -7,13 +7,18 @@ package org.elasticsearch.xpack.esql.expression.function.grouping; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.xpack.esql.LicenseAware; import org.elasticsearch.xpack.esql.SupportsObservabilityTier; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -21,16 +26,29 @@ import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; +import 
java.util.TreeMap; +import static java.util.Map.entry; +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat.REGEX; import static org.elasticsearch.xpack.esql.SupportsObservabilityTier.ObservabilityTier.COMPLETE; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; /** * Categorizes text messages. @@ -42,14 +60,23 @@ *

*/ @SupportsObservabilityTier(tier = COMPLETE) -public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements LicenseAware { +public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements OptionalArgument, LicenseAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "Categorize", Categorize::new ); + private static final String ANALYZER = "analyzer"; + private static final String OUTPUT_FORMAT = "output_format"; + private static final String SIMILARITY_THRESHOLD = "similarity_threshold"; + + private static final Map ALLOWED_OPTIONS = new TreeMap<>( + Map.ofEntries(entry(ANALYZER, KEYWORD), entry(OUTPUT_FORMAT, KEYWORD), entry(SIMILARITY_THRESHOLD, INTEGER)) + ); + private final Expression field; + private final Expression options; @FunctionInfo( returnType = "keyword", @@ -70,21 +97,56 @@ public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction ) public Categorize( Source source, - @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field - + @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field, + @MapParam( + name = "options", + description = "(Optional) Categorize additional options as <>.", + params = { + @MapParam.MapParamEntry( + name = ANALYZER, + type = "keyword", + valueHint = { "standard" }, + description = "Analyzer used to convert the field into tokens for text categorization." + ), + @MapParam.MapParamEntry( + name = OUTPUT_FORMAT, + type = "keyword", + valueHint = { "regex", "tokens" }, + description = "The output format of the categories. Defaults to regex." + ), + @MapParam.MapParamEntry( + name = SIMILARITY_THRESHOLD, + type = "integer", + valueHint = { "70" }, + description = "The minimum percentage of token weight that must match for text to be added to the category bucket. 
" + + "Must be between 1 and 100. The larger the value the narrower the categories. " + + "Larger values will increase memory usage and create narrower categories. Defaults to 70." + ), }, + optional = true + ) Expression options ) { - super(source, List.of(field)); + super(source, options == null ? List.of(field) : List.of(field, options)); this.field = field; + this.options = options; } private Categorize(StreamInput in) throws IOException { - this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class)); + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS) + ? in.readOptionalNamedWriteable(Expression.class) + : null + ); } @Override public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); out.writeNamedWriteable(field); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)) { + out.writeOptionalNamedWriteable(options); + } } @Override @@ -107,7 +169,48 @@ public Nullability nullable() { @Override protected TypeResolution resolveType() { - return isString(field(), sourceText(), DEFAULT); + return isString(field(), sourceText(), DEFAULT).and( + Options.resolve(options, source(), SECOND, ALLOWED_OPTIONS, this::verifyOptions) + ); + } + + private void verifyOptions(Map optionsMap) { + if (options == null) { + return; + } + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); + if (similarityThreshold != null) { + if (similarityThreshold <= 0 || similarityThreshold > 100) { + throw new InvalidArgumentException( + format("invalid similarity threshold [{}], expecting a number between 1 and 100, inclusive", similarityThreshold) + ); + } + } + String outputFormat = (String) optionsMap.get(OUTPUT_FORMAT); + if (outputFormat != null) { + try { + OutputFormat.valueOf(outputFormat.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw 
new InvalidArgumentException( + format(null, "invalid output format [{}], expecting one of [REGEX, TOKENS]", outputFormat) + ); + } + } + } + + public CategorizeDef categorizeDef() { + Map optionsMap = new HashMap<>(); + if (options != null) { + Options.populateMap((MapExpression) options, optionsMap, source(), SECOND, ALLOWED_OPTIONS); + } + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); + String outputFormatString = (String) optionsMap.get(OUTPUT_FORMAT); + OutputFormat outputFormat = outputFormatString == null ? null : OutputFormat.valueOf(outputFormatString.toUpperCase(Locale.ROOT)); + return new CategorizeDef( + (String) optionsMap.get("analyzer"), + outputFormat == null ? REGEX : outputFormat, + similarityThreshold == null ? 70 : similarityThreshold + ); } @Override @@ -117,12 +220,12 @@ public DataType dataType() { @Override public Categorize replaceChildren(List newChildren) { - return new Categorize(source(), newChildren.get(0)); + return new Categorize(source(), newChildren.get(0), newChildren.size() > 1 ? 
newChildren.get(1) : null); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Categorize::new, field); + return NodeInfo.create(this, Categorize::new, field, options); } public Expression field() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 61528521c3749..cab5ec862d7f5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; @@ -53,10 +54,10 @@ import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static 
org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; @@ -198,7 +199,7 @@ public DataType dataType() { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions()); + return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS)); } private TypeResolution resolveField() { @@ -221,37 +222,6 @@ private TypeResolution resolveK() { .and(isNotNull(k(), sourceText(), THIRD)); } - private TypeResolution resolveOptions() { - if (options() != null) { - TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); - if (resolution.unresolved()) { - return resolution; - } - // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); - if (resolution.unresolved()) { - return resolution; - } - - try { - knnQueryOptions(); - } catch (InvalidArgumentException e) { - return new TypeResolution(e.getMessage()); - } - } - return TypeResolution.TYPE_RESOLVED; - } - - private Map knnQueryOptions() throws InvalidArgumentException { - if (options() == null) { - return Map.of(); - } - - Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); - return matchOptions; - } - @Override public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); @@ -307,7 +277,7 @@ public Expression withFilters(List filterExpressions) { private Map queryOptions() throws InvalidArgumentException { Map options = new HashMap<>(); if (options() != 
null) { - populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), options, source(), FOURTH, ALLOWED_OPTIONS); } return options; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java index dd7ee26aa84bd..8fe9ccc18c006 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; @@ -137,13 +138,13 @@ private static Expression transformNonEvaluatableGroupingFunction( List newChildren = new ArrayList<>(gf.children().size()); for (Expression ex : gf.children()) { - if (ex instanceof Attribute == false) { // TODO: foldables shouldn't require eval'ing either + if (ex instanceof Attribute || ex instanceof MapExpression) { + newChildren.add(ex); + } else { // TODO: foldables shouldn't require eval'ing either var alias = new Alias(ex.source(), syntheticName(ex, gf, counter++), ex, null, true); evals.add(alias); newChildren.add(alias.toAttribute()); childrenChanged = true; - } else { - newChildren.add(ex); } } diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index a5d19fcc3fb14..e45fe2b0e81d8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -343,8 +343,12 @@ BlockHash.GroupSpec toHashGroupSpec() { if (channel == null) { throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead"); } - - return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null); + return new BlockHash.GroupSpec( + channel, + elementType(), + Alias.unwrap(expression) instanceof Categorize categorize ? categorize.categorizeDef() : null, + null + ); } ElementType elementType() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 4951479514b72..5d4260eb4ee66 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1972,6 +1972,57 @@ public void testCategorizeWithFilteredAggregations() { ); } + public void testCategorizeInvalidOptionsField() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + assertEquals( + "1:31: second argument of [CATEGORIZE(last_name, first_name)] must be a map expression, received [first_name]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, first_name)") + ); + assertEquals( + "1:31: Invalid option [blah] in [CATEGORIZE(last_name, { \"blah\": 42 })], " + + "expected one of 
[analyzer, output_format, similarity_threshold]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"blah\": 42 })") + ); + } + + public void testCategorizeOptionOutputFormat() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"regex\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"REGEX\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"tokens\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"ToKeNs\" })"); + assertEquals( + "1:31: invalid output format [blah], expecting one of [REGEX, TOKENS]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"blah\" })") + ); + assertEquals( + "1:31: invalid output format [42], expecting one of [REGEX, TOKENS]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": 42 })") + ); + } + + public void testCategorizeOptionSimilarityThreshold() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 1 })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 100 })"); + assertEquals( + "1:31: invalid similarity threshold [0], expecting a number between 1 and 100, inclusive", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 0 })") + ); + assertEquals( + "1:31: invalid similarity threshold [101], expecting a number between 1 and 100, inclusive", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 101 })") + ); + assertEquals( + "1:31: Invalid option [similarity_threshold] in [CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })], " + + 
"cannot cast [blah] to [integer]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })") + ); + } + public void testChangePoint() { assumeTrue("change_point must be enabled", EsqlCapabilities.Cap.CHANGE_POINT.isEnabled()); var airports = AnalyzerTestUtils.analyzer(loadMapping("mapping-airports.json", "airports")); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java index f674f9b2c3d72..97d5b8e3ece96 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java @@ -27,7 +27,7 @@ protected List cases() { @Override protected Expression build(Source source, List args) { - return new Categorize(source, args.get(0)); + return new Categorize(source, args.get(0), args.size() > 1 ? args.get(1) : null); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java index f69bb7eb3e7bb..296d624ee1777 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java @@ -61,7 +61,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new Categorize(source, args.get(0)); + return new Categorize(source, args.get(0), args.size() > 1 ? 
args.get(1) : null); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java index 96e26fbd37a4c..ae30dce97ce5a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java @@ -269,7 +269,7 @@ public void testNullBucketGetsFolded() { } public void testNullCategorizeGroupingNotFolded() { - Categorize categorize = new Categorize(EMPTY, NULL); + Categorize categorize = new Categorize(EMPTY, NULL, NULL); assertEquals(categorize, foldNull(categorize)); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java index 5686f3734f36e..47e20d1a56bf2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java @@ -162,6 +162,13 @@ public BytesRef[] getKeyTokens() { return Arrays.stream(keyTokenIndexes).mapToObj(index -> baseTokens[index]).toArray(BytesRef[]::new); } + public String getKeyTokensString() { + return Arrays.stream(keyTokenIndexes) + .mapToObj(index -> baseTokens[index]) + .map(BytesRef::utf8ToString) + .collect(Collectors.joining(" ")); + } + public String getRegex() { if (keyTokenIndexes.length == 0 || orderedCommonTokenBeginIndex == orderedCommonTokenEndIndex) { return ".*";