Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
package org.elasticsearch.xpack.esql.core.expression;

public enum Nullability {
TRUE, // Whether the expression can become null
FALSE, // The expression can never become null
UNKNOWN // Cannot determine if the expression supports possible null folding
/**
* The expression can become null
*/
TRUE,

/**
* The expression can never become null
*/
FALSE,

/**
* Cannot determine if the expression supports possible null folding
*/
UNKNOWN
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

Expand Down Expand Up @@ -95,41 +95,114 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}

try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
for (int i = randomInt(2); i < 3; i++) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
}
} finally {
page.releaseBlocks();
}

// TODO: randomize and try multiple pages.
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
// TODO: also test the lookup method and other stuff.
// TODO: randomize values? May give wrong results
// TODO: assert the categorizer state after adding pages.
}

/**
 * Feeds a page with multivalued positions into a {@link CategorizeBlockHash} running in
 * {@link AggregatorMode#SINGLE} and checks both the group ids handed to the {@code AddInput}
 * and the resulting hash state.
 * <p>
 * The page has three positions (4 values, 1 value, 2 values) plus, randomly, a fourth position that is
 * either null or an empty string; both variants are expected to land in group 0.
 */
public void testCategorizeRawMultivalue() {
    final Page page;
    boolean withNull = randomBoolean();
    final int positions = 3 + (withNull ? 1 : 0);
    try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
        // Position 0: four values
        builder.beginPositionEntry();
        builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
        builder.appendBytesRef(new BytesRef("Connection error"));
        builder.appendBytesRef(new BytesRef("Connection error"));
        builder.appendBytesRef(new BytesRef("Connection error"));
        builder.endPositionEntry();
        // Position 1: single value
        builder.appendBytesRef(new BytesRef("Disconnected"));
        // Position 2: two values
        builder.beginPositionEntry();
        builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
        builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
        builder.endPositionEntry();
        // Optional position 3: a null or an empty string
        if (withNull) {
            if (randomBoolean()) {
                builder.appendNull();
            } else {
                builder.appendBytesRef(new BytesRef(""));
            }
        }
        page = new Page(builder.build());
    }

    try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
        // Add the same page a random number of times: the asserted group ids and hash state
        // are identical on every iteration.
        for (int i = randomInt(2); i < 3; i++) {
            hash.add(page, new GroupingAggregatorFunction.AddInput() {
                @Override
                public void add(int positionOffset, IntBlock groupIds) {
                    // Expected value first, so that a failure message reports the right "expected" number
                    assertEquals(positions, groupIds.getPositionCount());

                    // The group ids block must keep the multivalue layout of the input block
                    assertThat(groupIds.getFirstValueIndex(0), equalTo(0));
                    assertThat(groupIds.getValueCount(0), equalTo(4));
                    assertThat(groupIds.getFirstValueIndex(1), equalTo(4));
                    assertThat(groupIds.getValueCount(1), equalTo(1));
                    assertThat(groupIds.getFirstValueIndex(2), equalTo(5));
                    assertThat(groupIds.getValueCount(2), equalTo(2));

                    assertEquals(1, groupIds.getInt(0));
                    assertEquals(2, groupIds.getInt(1));
                    assertEquals(2, groupIds.getInt(2));
                    assertEquals(2, groupIds.getInt(3));
                    assertEquals(3, groupIds.getInt(4));
                    assertEquals(1, groupIds.getInt(5));
                    assertEquals(1, groupIds.getInt(6));
                    if (withNull) {
                        assertEquals(0, groupIds.getInt(7));
                    }
                }

                @Override
                public void add(int positionOffset, IntVector groupIds) {
                    add(positionOffset, groupIds.asBlock());
                }

                @Override
                public void close() {
                    fail("hashes should not close AddInput");
                }
            });

            assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
        }
    } finally {
        page.releaseBlocks();
    }
}

public void testCategorizeIntermediate() {
Expand Down Expand Up @@ -226,18 +299,18 @@ public void close() {
page2.releaseBlocks();
}

try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) {
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
.collect(Collectors.toList());
if (withNull) {
assertEquals(Set.of(0, 1, 2), values);
assertEquals(List.of(0, 1, 2), values);
} else {
assertEquals(Set.of(1, 2), values);
assertEquals(List.of(1, 2), values);
}
}

Expand All @@ -252,28 +325,39 @@ public void close() {
}
});

intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
// 0 matches an existing category (Connected to ...), and the others are new.
assertEquals(Set.of(1, 3, 4), values);
}
for (int i = randomInt(2); i < 3; i++) {
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toList());
// The category IDs {1, 2, 3} should map to groups {1, 3, 4}, because
// 1 matches an existing category (Connected to ...), and the others are new.
assertEquals(List.of(3, 1, 4), values);
}

@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}
@Override
public void add(int positionOffset, IntVector groupIds) {
add(positionOffset, groupIds.asBlock());
}

@Override
public void close() {
fail("hashes should not close AddInput");
}
});
@Override
public void close() {
fail("hashes should not close AddInput");
}
});

assertHashState(
intermediateHash,
withNull,
".*?Connected.+?to.*?",
".*?Connection.+?error.*?",
".*?Disconnected.*?",
".*?System.+?shutdown.*?"
);
}
} finally {
intermediatePage1.releaseBlocks();
intermediatePage2.releaseBlocks();
Expand Down Expand Up @@ -457,4 +541,49 @@ public void testCategorize_withDriver() {
/**
 * Creates a {@link BlockHash.GroupSpec} built from the fixed arguments {@code 0},
 * {@link ElementType#BYTES_REF} and {@code true}.
 * <p>
 * NOTE(review): the first argument is presumably the input channel and the trailing {@code true}
 * presumably marks the group as a CATEGORIZE grouping — confirm against {@code BlockHash.GroupSpec}.
 */
private BlockHash.GroupSpec makeGroupSpec() {
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
}

/**
 * Verifies the observable state of a {@link CategorizeBlockHash}: its keys, its non-empty group ids
 * and its seen-group-ids bitset.
 *
 * @param hash         the hash under test
 * @param withNull     whether a null key is expected; when true it must sit at position 0 of the keys block
 * @param expectedKeys the expected non-null keys, in order, compared for exact string equality
 */
private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
    final int nullOffset = withNull ? 1 : 0;
    final int expectedGroupCount = expectedKeys.length + nullOffset;

    // Keys: a single block, with the optional null first and the expected keys after it
    Block[] keys = null;
    try {
        keys = hash.getKeys();
        assertThat(keys, arrayWithSize(1));

        BytesRefBlock keyBlock = (BytesRefBlock) keys[0];
        assertThat(keyBlock.getPositionCount(), equalTo(expectedGroupCount));

        if (withNull) {
            assertTrue(keyBlock.isNull(0));
        }

        int position = nullOffset;
        for (String expectedKey : expectedKeys) {
            String actualKey = keyBlock.getBytesRef(position, new BytesRef()).utf8ToString();
            assertThat(actualKey, equalTo(expectedKey));
            position++;
        }
    } finally {
        if (keys != null) {
            Releasables.close(keys);
        }
    }

    // nonEmpty(): consecutive group ids, starting at 0 when there is a null group and at 1 otherwise
    try (IntVector groupIds = hash.nonEmpty()) {
        assertThat(groupIds.getPositionCount(), equalTo(expectedGroupCount));

        int expectedGroupId = 1 - nullOffset;
        for (int i = 0; i < expectedGroupCount; i++) {
            assertThat(groupIds.getInt(i), equalTo(expectedGroupId));
            expectedGroupId++;
        }
    }

    // seenGroupIds(): group 0 is set only when there is a null group; every other group is always set
    try (var seen = hash.seenGroupIds(blockFactory.bigArrays())) {
        assertThat(seen.get(0), equalTo(withNull));

        for (int groupId = 1; groupId <= expectedKeys.length; groupId++) {
            assertThat(seen.get(groupId), equalTo(true));
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,29 @@ COUNT():long | category:keyword
7 | null
;

on const null
required_capability: categorize_v5

FROM sample_data
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(null)
| SORT category
;

COUNT():long | SUM(event_duration):long | category:keyword
7 | 23231327 | null
;

on null row
required_capability: categorize_v5

ROW message = null, str = ["a", "b", "c"]
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
;

COUNT():long | VALUES(str):keyword | category:keyword
1 | [a, b, c] | null
;

filtering out all data
required_capability: categorize_v5

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
Expand Down Expand Up @@ -92,6 +93,12 @@ public boolean foldable() {
return false;
}

/**
 * Reports this expression as nullable.
 *
 * @return always {@link Nullability#TRUE}
 */
@Override
public Nullability nullable() {
// Both nulls and empty strings result in null values
return Nullability.TRUE;
}

@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ public Expression rule(Expression e) {
if (Expressions.isGuaranteedNull(in.value())) {
return Literal.of(in, null);
}
} else if (e instanceof Alias == false
&& e.nullable() == Nullability.TRUE
} else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE
// Categorize function stays as a STATS grouping (It isn't moved to an early EVAL like other groupings),
// so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value.
&& e instanceof Categorize == false
&& Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) {
return Literal.of(e, null);
Expand Down