Skip to content

Commit

Permalink
Support all aggregate/window functions (#26)
Browse files Browse the repository at this point in the history
* Check no crash (regular fuzzing) for all aggregate functions

* Support window function regular fuzzing
  • Loading branch information
2010YOUY01 authored Aug 23, 2024
1 parent 2081229 commit 67fe236
Show file tree
Hide file tree
Showing 12 changed files with 469 additions and 94 deletions.
16 changes: 12 additions & 4 deletions src/sqlancer/datafusion/DataFusionErrors.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("MedianAccumulator not supported for median");
errors.add("Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal");
errors.add("digest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal ");
errors.add("Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented");
errors.add("Arrow error: Invalid argument error: Invalid arithmetic operation: Utf8");
errors.add("There is only support Literal types for field at idx:");
errors.add("nth_value not supported for n:");
errors.add("Invalid argument error: Nested comparison: List(");

/*
* Known bugs
Expand All @@ -47,10 +52,12 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("bitwise"); // https://github.com/apache/datafusion/issues/11260
errors.add("Sort expressions cannot be empty for streaming merge."); // https://github.com/apache/datafusion/issues/11561
errors.add("compute_utf8_flag_op_scalar failed to cast literal value NULL for operation"); // https://github.com/apache/datafusion/issues/11623
errors.add("Schema error: No field named"); // https://github.com/apache/datafusion/issues/11635
errors.add("Min/Max accumulator not implemented for type Null."); // https://github.com/apache/datafusion/issues/11749
errors.add("APPROX_PERCENTILE_CONT_WITH_WEIGHT"); // TODO issue
errors.add("APPROX_MEDIAN"); // TODO issue
errors.add("Schema error: No field named "); // https://github.com/apache/datafusion/issues/12006
errors.add("Internal error: PhysicalExpr Column references column"); // https://github.com/apache/datafusion/issues/12012
errors.add("APPROX_"); // https://github.com/apache/datafusion/issues/12058
errors.add("External error: task"); // https://github.com/apache/datafusion/issues/12057
errors.add("NTH_VALUE"); // https://github.com/apache/datafusion/issues/12073
errors.add("SUBSTR"); // https://github.com/apache/datafusion/issues/12129

/*
* False positives
Expand All @@ -59,6 +66,7 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr
// is generated in where
// clause

/*
* Not critical, investigate in the future
*/
Expand Down
15 changes: 12 additions & 3 deletions src/sqlancer/datafusion/DataFusionOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sqlancer.datafusion.DataFusionOptions.DataFusionOracleFactory;
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
import sqlancer.datafusion.test.DataFusionNoCrashAggregate;
import sqlancer.datafusion.test.DataFusionNoCrashWindow;
import sqlancer.datafusion.test.DataFusionNoRECOracle;
import sqlancer.datafusion.test.DataFusionQueryPartitioningAggrTester;
import sqlancer.datafusion.test.DataFusionQueryPartitioningHavingTester;
Expand All @@ -26,10 +27,12 @@ public class DataFusionOptions implements DBMSSpecificOptions<DataFusionOracleFa
@Override
public List<DataFusionOracleFactory> getTestOracleFactory() {
return Arrays.asList(
// DataFusionOracleFactory.NO_CRASH_AGGREGATE
// DataFusionOracleFactory.NO_CRASH_WINDOW,
// DataFusionOracleFactory.NO_CRASH_AGGREGATE,
DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE
/* DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE */
/* , DataFusionOracleFactory.QUERY_PARTITIONING_HAVING */);
// DataFusionOracleFactory.QUERY_PARTITIONING_AGGREGATE
// ,DataFusionOracleFactory.QUERY_PARTITIONING_HAVING
);
}

public enum DataFusionOracleFactory implements OracleFactory<DataFusionGlobalState> {
Expand Down Expand Up @@ -62,6 +65,12 @@ public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalStat
public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalState) throws SQLException {
return new DataFusionNoCrashAggregate(globalState);
}
},
NO_CRASH_WINDOW {
@Override
public TestOracle<DataFusionGlobalState> create(DataFusionGlobalState globalState) throws SQLException {
return new DataFusionNoCrashWindow(globalState);
}
}
}

Expand Down
29 changes: 29 additions & 0 deletions src/sqlancer/datafusion/DataFusionToStringVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sqlancer.datafusion.ast.DataFusionExpression;
import sqlancer.datafusion.ast.DataFusionSelect;
import sqlancer.datafusion.ast.DataFusionSelect.DataFusionFrom;
import sqlancer.datafusion.ast.DataFusionWindowExpr;

public class DataFusionToStringVisitor extends NewToStringVisitor<DataFusionExpression> {

Expand All @@ -34,6 +35,8 @@ public void visitSpecific(Node<DataFusionExpression> expr) {
visit((DataFusionSelect) expr);
} else if (expr instanceof DataFusionFrom) {
visit((DataFusionFrom) expr);
} else if (expr instanceof DataFusionWindowExpr) {
visit((DataFusionWindowExpr) expr);
} else {
throw new AssertionError(expr.getClass());
}
Expand Down Expand Up @@ -134,4 +137,30 @@ private void visit(DataFusionSelect select) {
}
}

private void visit(DataFusionWindowExpr window) {
// Window function
visit(window.windowFunc);

// Over clause
// -----------
sb.append(" OVER (");

if (!window.partitionByList.isEmpty()) {
sb.append("PARTITION BY ");
visit(window.partitionByList);
}

if (!window.orderByList.isEmpty()) {
sb.append(" ORDER BY ");
visit(window.orderByList);
}

if (window.frameClause.isPresent()) {
sb.append(" ");
sb.append(window.frameClause.get());
}

sb.append(")");
}

}
1 change: 1 addition & 0 deletions src/sqlancer/datafusion/DataFusionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public void appendToLog(DataFusionLogType logType, String logContent) {
try {
logFileWriter.write(logLineHeader);
logFileWriter.write(logContent);
logFileWriter.write("\n");
logFileWriter.flush();
} catch (IOException e) {
String err = "Failed to write to " + logType + " log: " + e.getMessage();
Expand Down
10 changes: 1 addition & 9 deletions src/sqlancer/datafusion/ast/DataFusionSelect.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,7 @@ public static DataFusionSelect getRandomSelect(DataFusionGlobalState state) {
// This method assume `DataFusionSelect` is propoerly initialized with `getRandomSelect()`
public void setAggregates(DataFusionGlobalState state) {
// group by exprs (e.g. group by v1, abs(v2))
List<Node<DataFusionExpression>> groupByExprs = new ArrayList<>();
int nGroupBy = state.getRandomly().getInteger(0, 3);
if (Randomly.getBoolean()) {
// Generate expressions like (v1+1, v2 *2)
groupByExprs = this.exprGenGroupBy.generateExpressions(nGroupBy);
} else {
// Generate simple column references like v1, v2
groupByExprs = this.exprGenGroupBy.generateColumns(nGroupBy);
}
List<Node<DataFusionExpression>> groupByExprs = this.exprGenGroupBy.generateExpressionsPreferColumns();

// Generate aggregates like SUM(v1), MAX(V2)
this.exprGenAggregate.supportAggregate = true;
Expand Down
126 changes: 126 additions & 0 deletions src/sqlancer/datafusion/ast/DataFusionWindowExpr.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package sqlancer.datafusion.ast;

import static sqlancer.datafusion.gen.DataFusionExpressionGenerator.getIntegerPreferSmallPositive;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import sqlancer.Randomly;
import sqlancer.common.ast.newast.Node;
import sqlancer.datafusion.DataFusionProvider;
import sqlancer.datafusion.DataFusionSchema;
import sqlancer.datafusion.gen.DataFusionExpressionGenerator;

// This Node is for single window expression
// e.g. `rank() over()`
public class DataFusionWindowExpr implements Node<DataFusionExpression> {
// Window expression syntax reference:
// function([expr])
// OVER(
// [PARTITION BY expr[, …]]
// [ORDER BY expr [ ASC | DESC ][, …]]
// [ frame_clause ]
// )

// Window Function
// ===============
public Node<DataFusionExpression> windowFunc;

// Over Clause components
// ======================

// Optional. Empty list means not present.
public List<Node<DataFusionExpression>> partitionByList = new ArrayList<>();
// Optional. Empty list means not present.
public List<Node<DataFusionExpression>> orderByList = new ArrayList<>();
// Optional. Empty option means not present.
public Optional<String> frameClause;

// Others
// =======

// To generate `partitionByList` and `orderByList`
public DataFusionExpressionGenerator exprGen;

public static DataFusionWindowExpr getRandomWindowClause(DataFusionExpressionGenerator gen,
DataFusionProvider.DataFusionGlobalState state) {
DataFusionWindowExpr windowExpr = new DataFusionWindowExpr();
windowExpr.exprGen = gen;

// setup window function e.g. 'rank()'
windowExpr.exprGen.supportWindow = true;
windowExpr.windowFunc = windowExpr.exprGen
.generateExpression(DataFusionSchema.DataFusionDataType.getRandomWithoutNull());
windowExpr.exprGen.supportWindow = false;

// setup `partition by`
windowExpr.partitionByList = windowExpr.exprGen.generateExpressionsPreferColumns();

// setup `order by`
windowExpr.orderByList = windowExpr.exprGen.generateOrderBys();

// setup frame range
windowExpr.frameClause = DataFusionWindowExprFrame.getRandomFrame(state);

return windowExpr;
}

}

// 'frame_clause' is one of:
// { RANGE | ROWS | GROUPS } frame_start
// { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end
//
// 'frame_start' and 'frame_end' can be:
// UNBOUNDED PRECEDING
// offset PRECEDING
// CURRENT ROW
// offset FOLLOWING
// UNBOUNDED FOLLOWING
//
// offset is non-negative integer
// (but this class might generate something else to make it more chaotic)
//
// Reference:
// https://datafusion.apache.org/user-guide/sql/window_functions.html#syntax
final class DataFusionWindowExprFrame {
private static final List<String> FRAME_TYPES = Arrays.asList("RANGE", "ROWS", "GROUPS");

// Private constructor to prevent instantiation
private DataFusionWindowExprFrame() {
}

// The range epxression inside now only support integer liternal (not expr)
// So make it string instead of Node<DataFusionExpression> for simplicity
public static Optional<String> getRandomFrame(DataFusionProvider.DataFusionGlobalState state) {
if (Randomly.getBoolean()) {
return Optional.empty();
}

String frameType = Randomly.fromList(FRAME_TYPES);
String frameStart = generateFramePoint(state, true);
String frameEnd = generateFramePoint(state, false);

String repr;
if (Randomly.getBoolean()) {
repr = frameType + " " + frameStart;
} else {
repr = frameType + " BETWEEN " + frameStart + " AND " + frameEnd;
}

return Optional.of(repr);
}

private static String generateFramePoint(DataFusionProvider.DataFusionGlobalState state, boolean isStart) {
int offset = getIntegerPreferSmallPositive(state);
List<String> options = new ArrayList<>(Arrays.asList("UNBOUNDED PRECEDING", offset + " PRECEDING",
"CURRENT ROW", offset + " FOLLOWING", "UNBOUNDED FOLLOWING"));

if (!isStart && !Randomly.getBooleanWithRatherLowProbability()) {
options.remove("UNBOUNDED PRECEDING");
}
return Randomly.fromList(options);
}
}
74 changes: 47 additions & 27 deletions src/sqlancer/datafusion/gen/DataFusionBaseExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public String toString() {
*/
// Used to construct `src.common.ast.*Node`
public enum DataFusionBaseExprCategory {
UNARY_PREFIX, UNARY_POSTFIX, BINARY, FUNC, AGGREGATE
UNARY_PREFIX, UNARY_POSTFIX, BINARY, FUNC, AGGREGATE, WINDOW
}

/*
Expand Down Expand Up @@ -313,36 +313,56 @@ public enum DataFusionBaseExprType {
// Other Functions

// Aggregate Functions (General)
AGGR_MIN, AGGR_MAX, AGGR_SUM, AGGR_AVG, AGGR_COUNT, BIT_AND, BIT_OR, BIT_XOR, BOOL_AND, BOOL_OR, MEAN, MEDIAN,
FIRST_VALUE, LAST_VALUE,
AGGR_MIN, AGGR_MAX, AGGR_SUM, AGGR_AVG, AGGR_COUNT, AGGR_BIT_AND, AGGR_BIT_OR, AGGR_BIT_XOR, AGGR_BOOL_AND,
AGGR_BOOL_OR, AGGR_MEAN, AGGR_MEDIAN, AGGR_FIRST_VALUE, AGGR_LAST_VALUE,
// Aggregate Functiosn (Statistical)
CORR, // corr(v1, v2)
COVAR, // covar(v1, v2)
COVAR_POP, // covar_pop(v1, v2)
COVAR_SAMP, // covar_samp(v1, v2)
STDDEV, // stddev(v)
STDDEV_POP, // stddev_pop(v)
STDDEV_SAMP, // stddev_samp(v)
VAR, // var(v)
VAR_POP, // var_pop(v)
VAR_SAMP, // var_samp(v)
REGR_AVGX, // regr_avgx(y, x)
REGR_AVGY, // regr_avgy(y, x)
REGR_COUNT, // regr_count(y, x)
REGR_INTERCEPT, // regr_intercept(y, x)
REGR_R2, // regr_r2(y, x)
REGR_SLOPE, // regr_slope(y, x)
REGR_SXX, // regr_sxx(x)
REGR_SYY, // regr_syy(y)
REGR_SXY, // regr_sxy(x, y)
AGGR_CORR, // corr(v1, v2)
AGGR_COVAR, // covar(v1, v2)
AGGR_POP, // covar_pop(v1, v2)
AGGR_COVAR_SAMP, // covar_samp(v1, v2)
AGGR_STDDEV, // stddev(v)
AGGR_STDDEV_POP, // stddev_pop(v)
AGGR_STDDEV_SAMP, // stddev_samp(v)
AGGR_VAR, // var(v)
AGGR_VAR_POP, // var_pop(v)
AGGR_VAR_SAMP, // var_samp(v)
AGGR_REGR_AVGX, // regr_avgx(y, x)
AGGR_REGR_AVGY, // regr_avgy(y, x)
AGGR_REGR_COUNT, // regr_count(y, x)
AGGR_REGR_INTERCEPT, // regr_intercept(y, x)
AGGR_REGR_R2, // regr_r2(y, x)
AGGR_REGR_SLOPE, // regr_slope(y, x)
AGGR_REGR_SXX, // regr_sxx(x)
AGGR_REGR_SYY, // regr_syy(y)
AGGR_REGR_SXY, // regr_sxy(x, y)
// Aggregate Functions (Approximate)
APPROX_DISTINCT, // approx_distinct(expression)
APPROX_MEDIAN, // approx_median(expression)
APPROX_PERCENTILE_CONT, // approx_percentile_cont(expression, percentile)
APPROX_PERCENTILE_CONT2, // approx_percentile_cont(expression, percentile, centroids)
APPROX_PERCENTILE_CONT_WITH_WEIGHT // approx_percentile_cont_with_weight(expression, weight, percentile)
AGGR_APPROX_DISTINCT, // approx_distinct(expression)
AGGR_APPROX_MEDIAN, // approx_median(expression)
AGGR_APPROX_PERCENTILE_CONT, // approx_percentile_cont(expression, percentile)
AGGR_APPROX_PERCENTILE_CONT2, // approx_percentile_cont(expression, percentile, centroids)
AGGR_APPROX_PERCENTILE_CONT_WITH_WEIGHT, // approx_percentile_cont_with_weight(expression, weight, percentile)

// Array Aggregate functions

// Window Functions
WINDOW_RANK, // rank()
WINDOW_ROW_NUMBER, // row_number()
WINDOW_DENSE_RANK, // dense_rank()
WINDOW_NTILE, // ntile(2) - divides the partition into 2 groups
WINDOW_CUME_DIST, // cume_dist() - Relative rank of the current row: WINDOW(number of rows preceding or peer
// with current row) / (total rows)
WINDOW_PERCENT_RANK, // percent_rank() - Relative rank of the current row:WINDOW (rank - 1) / (total rows - 1)
WINDOW_LAG1, // 1 arg version of lag(expression, offset, default) - Returns value from rows WINDOWbefore the
// current row within the partition
WINDOW_LAG2, // 2 args version
WINDOW_LAG3, // 3 args version
WINDOW_LEAD1, // 1 arg version of lead(expression, offset, default) - Returns value from rows after the current
// row within the partition
WINDOW_LEAD2, // 2 args version
WINDOW_LEAD3, // 3 args version
WINDOW_FIRST_VALUE, // first_value(expression) - Returns value from the first row of the window frame
WINDOW_LAST_VALUE, // last_value(expression) - Returns value from the last row of the window frame
WINDOW_NTH_VALUE, // nth_value(expression, n) - Returns value from the nth row of the window frame
}

/*
Expand Down
Loading

0 comments on commit 67fe236

Please sign in to comment.