Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 21 additions & 41 deletions velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,24 +498,6 @@ class ApproxPercentileAggregate : public exec::Aggregate {
VELOX_CHECK_EQ(argIndex, args.size());
}

/// Extract percentile info: the raw data, the length and the null-ness from
/// top-level ArrayVector.
static void extractPercentiles(
const ArrayVector* arrays,
vector_size_t indexInBaseVector,
const double*& data,
vector_size_t& len,
std::vector<bool>& isNull) {
auto elements = arrays->elements()->asFlatVector<double>();
auto offset = arrays->offsetAt(indexInBaseVector);
data = elements->rawValues() + offset;
len = arrays->sizeAt(indexInBaseVector);
isNull.resize(len);
for (auto index = offset; index < offset + len; index++) {
isNull[index - offset] = elements->isNullAt(index);
}
}

void checkSetPercentile(
const SelectivityVector& rows,
const BaseVector& vec) {
Expand All @@ -536,44 +518,42 @@ class ApproxPercentileAggregate : public exec::Aggregate {
}

bool isArray;
const double* data;
vector_size_t offset;
vector_size_t len;
std::vector<bool> isNull;
if (base->typeKind() == TypeKind::DOUBLE) {
isArray = false;
data = decoded.data<double>() + baseFirstRow;
offset = rows.begin();
len = 1;
isNull = {decoded.isNullAt(rows.begin())};
} else if (base->typeKind() == TypeKind::ARRAY) {
isArray = true;
auto arrays = base->asUnchecked<ArrayVector>();
VELOX_USER_CHECK(
arrays->elements()->isFlatEncoding(),
"Only flat encoding is allowed for percentile array elements");
extractPercentiles(arrays, baseFirstRow, data, len, isNull);
decoded.decode(*arrays->elements());
offset = arrays->offsetAt(baseFirstRow);
len = arrays->sizeAt(baseFirstRow);
} else {
VELOX_USER_FAIL(
"Incorrect type for percentile: {}", base->type()->toString());
}
checkSetPercentile(isArray, data, len, isNull);
checkSetPercentile(isArray, decoded, offset, len);
}

void checkSetPercentile(
bool isArray,
const double* data,
vector_size_t len,
const std::vector<bool>& isNull) {
const DecodedVector& percentiles,
vector_size_t offset,
vector_size_t len) {
if (!percentiles_) {
VELOX_USER_CHECK_GT(len, 0, "Percentile cannot be empty");
percentiles_ = {
.values = std::vector<double>(len),
.isArray = isArray,
};
for (vector_size_t i = 0; i < len; ++i) {
VELOX_USER_CHECK(!isNull[i], "Percentile cannot be null");
VELOX_USER_CHECK_GE(data[i], 0, "Percentile must be between 0 and 1");
VELOX_USER_CHECK_LE(data[i], 1, "Percentile must be between 0 and 1");
percentiles_->values[i] = data[i];
VELOX_USER_CHECK(!percentiles.isNullAt(i), "Percentile cannot be null");
auto value = percentiles.valueAt<double>(offset + i);
VELOX_USER_CHECK_GE(value, 0, "Percentile must be between 0 and 1");
VELOX_USER_CHECK_LE(value, 1, "Percentile must be between 0 and 1");
percentiles_->values[i] = value;
}
} else {
VELOX_USER_CHECK_EQ(
Expand All @@ -586,7 +566,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
"Percentile argument must be constant for all input rows");
for (vector_size_t i = 0; i < len; ++i) {
VELOX_USER_CHECK_EQ(
data[i],
percentiles.valueAt<double>(offset + i),
percentiles_->values[i],
"Percentile argument must be constant for all input rows");
}
Expand Down Expand Up @@ -752,12 +732,12 @@ class ApproxPercentileAggregate : public exec::Aggregate {
}

bool isArray = percentileIsArray->valueAt(i);
const double* data;
vector_size_t len;
std::vector<bool> isNull;
extractPercentiles(
percentilesBase, indexInBaseVector, data, len, isNull);
checkSetPercentile(isArray, data, len, isNull);
DecodedVector decodedElements(*percentilesBase->elements());
checkSetPercentile(
isArray,
decodedElements,
percentilesBase->offsetAt(indexInBaseVector),
percentilesBase->sizeAt(indexInBaseVector));

if (!accuracy->isNullAt(i)) {
checkSetAccuracy(accuracy->valueAt(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ TEST_F(ApproxPercentileTest, finalAggregateAccuracy) {
assertQuery(op, "SELECT 5");
}

TEST_F(ApproxPercentileTest, invalidEncoding) {
TEST_F(ApproxPercentileTest, nonFlatPercentileArray) {
auto indices = AlignedBuffer::allocate<vector_size_t>(3, pool());
auto rawIndices = indices->asMutable<vector_size_t>();
std::iota(rawIndices, rawIndices + indices->size(), 0);
Expand All @@ -421,10 +421,8 @@ TEST_F(ApproxPercentileTest, invalidEncoding) {
.values({rows})
.singleAggregation({}, {"approx_percentile(c0, c1)"})
.planNode();
AssertQueryBuilder assertQuery(plan);
VELOX_ASSERT_THROW(
assertQuery.copyResults(pool()),
"Only flat encoding is allowed for percentile array elements");
auto expected = makeRowVector({makeArrayVector<int32_t>({{0, 5, 9}})});
AssertQueryBuilder(plan).assertResults(expected);
}

TEST_F(ApproxPercentileTest, invalidWeight) {
Expand Down