Skip to content

Commit

Permalink
Assert correctness on AST ran through SubexpressionOptimizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611245287
  • Loading branch information
l46kok authored and copybara-github committed Feb 28, 2024
1 parent 73d29cf commit b302caa
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
Expand All @@ -39,6 +40,7 @@
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.navigation.CelNavigableAst;
Expand Down Expand Up @@ -125,9 +127,14 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c

@Override
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder) {
return cseOptions.enableCelBlock()
? optimizeUsingCelBlock(navigableAst, celBuilder)
: optimizeUsingCelBind(navigableAst);
CelAbstractSyntaxTree ast =
cseOptions.enableCelBlock()
? optimizeUsingCelBlock(navigableAst, celBuilder)
: optimizeUsingCelBind(navigableAst);

verifyOptimizedAstCorrectness(ast);

return ast;
}

private CelAbstractSyntaxTree optimizeUsingCelBlock(
Expand Down Expand Up @@ -239,6 +246,73 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
return tagAstExtension(astToModify);
}

/**
* Asserts that the optimized AST has no correctness issues.
*
* @throws com.google.common.base.VerifyException if the optimized AST is malformed.
*/
@VisibleForTesting
static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
CelNavigableAst celNavigableAst = CelNavigableAst.fromAst(ast);

ImmutableList<CelExpr> allCelBlocks =
celNavigableAst
.getRoot()
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION))
.collect(toImmutableList());
if (allCelBlocks.isEmpty()) {
return;
}

CelExpr celBlockExpr = allCelBlocks.get(0);
Verify.verify(
allCelBlocks.size() == 1,
"Expected 1 cel.block function to be present but found %s",
allCelBlocks.size());
Verify.verify(
celNavigableAst.getRoot().expr().equals(celBlockExpr),
"Expected cel.block to be present at root");

// Assert correctness on block indices used in subexpressions
CelCall celBlockCall = celBlockExpr.call();
ImmutableList<CelExpr> subexprs = celBlockCall.args().get(0).createList().elements();
for (int i = 0; i < subexprs.size(); i++) {
verifyBlockIndex(subexprs.get(i), i);
}

// Assert correctness on block indices used in block result
CelExpr blockResult = celBlockCall.args().get(1);
verifyBlockIndex(blockResult, subexprs.size());
boolean resultHasAtLeastOneBlockIndex =
CelNavigableExpr.fromExpr(blockResult)
.allNodes()
.map(CelNavigableExpr::expr)
.anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX));
Verify.verify(
resultHasAtLeastOneBlockIndex,
"Expected at least one reference of index in cel.block result");
}

private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
boolean areAllIndicesValid =
CelNavigableExpr.fromExpr(celExpr)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX))
.map(CelExpr::ident)
.allMatch(
blockIdent ->
Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length()))
< maxIndexValue);
Verify.verify(
areAllIndicesValid,
"Illegal block index found. The index value must be less than %s. Expr: %s",
maxIndexValue,
celExpr);
}

private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
// Tag the extension
CelSource.Builder celSourceBuilder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static dev.cel.common.CelOverloadDecl.newGlobalOverload;
import static org.junit.Assert.assertThrows;

import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.testing.junit.testparameterinjector.TestParameter;
Expand Down Expand Up @@ -481,6 +482,73 @@ public void blockIndex_invalidArgument_throws() {
assertThat(e).hasMessageThat().contains("undeclared reference");
}

@Test
public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception {
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions("cel.block([1, 2], cel.block([2], 3))");

VerifyException e =
assertThrows(
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.isEqualTo("Expected 1 cel.block function to be present but found 2");
}

@Test
public void verifyOptimizedAstCorrectness_celBlockNotAtRoot_throws() throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("1 + cel.block([1, 2], index0)");

VerifyException e =
assertThrows(
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e).hasMessageThat().isEqualTo("Expected cel.block to be present at root");
}

@Test
public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([1, index0], 2)");

VerifyException e =
assertThrows(
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.isEqualTo("Expected at least one reference of index in cel.block result");
}

@Test
@TestParameters("{source: 'cel.block([], index0)'}")
@TestParameters("{source: 'cel.block([1, 2], index2)'}")
public void verifyOptimizedAstCorrectness_indexOutOfBounds_throws(String source)
throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);

VerifyException e =
assertThrows(
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.contains("Illegal block index found. The index value must be less than");
}

@Test
@TestParameters("{source: 'cel.block([index0], index0)'}")
@TestParameters("{source: 'cel.block([1, index1, 2], index2)'}")
@TestParameters("{source: 'cel.block([1, 2, index2], index2)'}")
@TestParameters("{source: 'cel.block([index2, 1, 2], index2)'}")
public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(String source)
throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);

VerifyException e =
assertThrows(
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.contains("Illegal block index found. The index value must be less than");
}

/**
* Converts AST containing cel.block related test functions to internal functions (e.g: cel.block
* -> cel.@block)
Expand Down

0 comments on commit b302caa

Please sign in to comment.