Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
);
case TOP_N_LONGS -> List.of(
new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
);
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00);
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,26 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
/**
 * Top-N configuration for a group: sort position, direction, null ordering, and row limit.
 * NOTE(review): {@code order} appears to be the sort-key position (callers pass 0) — confirm.
 */
public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}

/**
 * Configuration for a BlockHash group spec that is doing text categorization.
 *
 * @param analyzer name of the analyzer to use, or {@code null} to use the default
 *     ES|QL categorization analyzer
 * @param outputFormat how a category's key is rendered in output blocks
 * @param similarityThreshold similarity threshold expressed as a percentage (0-100);
 *     it is divided by 100 before being handed to the categorizer
 */
public record CategorizeDef(String analyzer, OutputFormat outputFormat, int similarityThreshold) {
    /** How a category's key is rendered. */
    public enum OutputFormat {
        /** Render the category's regex. */
        REGEX,
        /** Render the category's key tokens string. */
        TOKENS
    }
}

public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) {
public GroupSpec(int channel, ElementType elementType) {
this(channel, elementType, false, null);
this(channel, elementType, null, null);
}

public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) {
this(channel, elementType, categorizeDef, null);
}

public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
this(channel, elementType, isCategorize, null);
public boolean isCategorize() {
return categorizeDef != null;
}
}

Expand Down Expand Up @@ -207,7 +217,13 @@ public static BlockHash buildCategorizeBlockHash(
int emitBatchSize
) {
if (groups.size() == 1) {
return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
return new CategorizeBlockHash(
blockFactory,
groups.get(0).channel,
aggregatorMode,
groups.get(0).categorizeDef,
analysisRegistry
);
} else {
assert groups.get(0).isCategorize();
assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
Expand Down Expand Up @@ -47,12 +46,13 @@
*/
public class CategorizeBlockHash extends BlockHash {

private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig
private static final CategorizationAnalyzerConfig DEFAULT_ANALYZER_CONFIG = CategorizationAnalyzerConfig
.buildStandardEsqlCategorizationAnalyzer();
private static final int NULL_ORD = 0;

private final int channel;
private final AggregatorMode aggregatorMode;
private final CategorizeDef categorizeDef;
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
private final CategorizeEvaluator evaluator;

Expand All @@ -64,28 +64,38 @@ public class CategorizeBlockHash extends BlockHash {
*/
private boolean seenNull = false;

CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
CategorizeBlockHash(
BlockFactory blockFactory,
int channel,
AggregatorMode aggregatorMode,
CategorizeDef categorizeDef,
AnalysisRegistry analysisRegistry
) {
super(blockFactory);

this.channel = channel;
this.aggregatorMode = aggregatorMode;
this.categorizeDef = categorizeDef;

this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
categorizeDef.similarityThreshold() / 100.0f
);

if (aggregatorMode.isInputPartial() == false) {
CategorizationAnalyzer analyzer;
CategorizationAnalyzer categorizationAnalyzer;
try {
Objects.requireNonNull(analysisRegistry);
analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
} catch (Exception e) {
CategorizationAnalyzerConfig config = categorizeDef.analyzer() == null
? DEFAULT_ANALYZER_CONFIG
: new CategorizationAnalyzerConfig.Builder().setAnalyzer(categorizeDef.analyzer()).build();
categorizationAnalyzer = new CategorizationAnalyzer(analysisRegistry, config);
} catch (IOException e) {
categorizer.close();
throw new RuntimeException(e);
}
this.evaluator = new CategorizeEvaluator(analyzer);
this.evaluator = new CategorizeEvaluator(categorizationAnalyzer);
} else {
this.evaluator = null;
}
Expand Down Expand Up @@ -114,7 +124,7 @@ public IntVector nonEmpty() {

@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
    // Fix: the merged-diff view kept both the old (SeenGroupIds.Range) and new (Range)
    // return statements; only the new one survives — the SeenGroupIds import was removed.
    // Ord 0 is reserved for null, so the range starts at 1 unless a null was seen,
    // and extends one past the current category count.
    return new Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
}

@Override
Expand Down Expand Up @@ -222,7 +232,7 @@ private Block buildFinalBlock() {
try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
result.appendNull();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
scratch.copyChars(getKeyString(category));
result.appendBytesRef(scratch.get());
scratch.clear();
}
Expand All @@ -232,14 +242,21 @@ private Block buildFinalBlock() {

try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
scratch.copyChars(getKeyString(category));
result.appendBytesRef(scratch.get());
scratch.clear();
}
return result.build().asBlock();
}
}

/**
 * Renders the key for a category according to the configured output format:
 * either the category's regex or its key tokens string.
 */
private String getKeyString(SerializableTokenListCategory category) {
    return switch (categorizeDef.outputFormat()) {
        case REGEX -> category.getRegex();
        case TOKENS -> category.getKeyTokensString();
    };
}

/**
* Similar implementation to an Evaluator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public class CategorizePackedValuesBlockHash extends BlockHash {
int emitBatchSize
) {
super(blockFactory);
assert specs.get(0).categorizeDef() != null;

this.specs = specs;
this.aggregatorMode = aggregatorMode;
blocks = new Block[specs.size()];
Expand All @@ -68,7 +70,13 @@ public class CategorizePackedValuesBlockHash extends BlockHash {

boolean success = false;
try {
categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
categorizeBlockHash = new CategorizeBlockHash(
blockFactory,
specs.get(0).channel(),
aggregatorMode,
specs.get(0).categorizeDef(),
analysisRegistry
);
packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
success = true;
} finally {
Expand Down
Loading
Loading