diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index a8e431ab..b55f4614 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -22,6 +22,7 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -39,6 +40,7 @@ import dev.cel.common.CelValidationException; import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.navigation.CelNavigableAst; @@ -125,9 +127,14 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c @Override public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder) { - return cseOptions.enableCelBlock() - ? optimizeUsingCelBlock(navigableAst, celBuilder) - : optimizeUsingCelBind(navigableAst); + CelAbstractSyntaxTree ast = + cseOptions.enableCelBlock() + ? optimizeUsingCelBlock(navigableAst, celBuilder) + : optimizeUsingCelBind(navigableAst); + + verifyOptimizedAstCorrectness(ast); + + return ast; } private CelAbstractSyntaxTree optimizeUsingCelBlock( @@ -239,6 +246,73 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( return tagAstExtension(astToModify); } + /** + * Asserts that the optimized AST has no correctness issues. + * + * @throws com.google.common.base.VerifyException if the optimized AST is malformed. + */ + @VisibleForTesting + static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) { + CelNavigableAst celNavigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allCelBlocks = + celNavigableAst + .getRoot() + .allNodes() + .map(CelNavigableExpr::expr) + .filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION)) + .collect(toImmutableList()); + if (allCelBlocks.isEmpty()) { + return; + } + + CelExpr celBlockExpr = allCelBlocks.get(0); + Verify.verify( + allCelBlocks.size() == 1, + "Expected 1 cel.block function to be present but found %s", + allCelBlocks.size()); + Verify.verify( + celNavigableAst.getRoot().expr().equals(celBlockExpr), + "Expected cel.block to be present at root"); + + // Assert correctness on block indices used in subexpressions + CelCall celBlockCall = celBlockExpr.call(); + ImmutableList subexprs = celBlockCall.args().get(0).createList().elements(); + for (int i = 0; i < subexprs.size(); i++) { + verifyBlockIndex(subexprs.get(i), i); + } + + // Assert correctness on block indices used in block result + CelExpr blockResult = celBlockCall.args().get(1); + verifyBlockIndex(blockResult, subexprs.size()); + boolean resultHasAtLeastOneBlockIndex = + CelNavigableExpr.fromExpr(blockResult) + .allNodes() + .map(CelNavigableExpr::expr) + .anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX)); + Verify.verify( + resultHasAtLeastOneBlockIndex, + "Expected at least one reference of index in cel.block result"); + } + + private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) { + boolean areAllIndicesValid = + CelNavigableExpr.fromExpr(celExpr) + .allNodes() + .map(CelNavigableExpr::expr) + .filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX)) + .map(CelExpr::ident) + .allMatch( + blockIdent -> + Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length())) + < maxIndexValue); + Verify.verify( + areAllIndicesValid, + "Illegal block index found. The index value must be less than %s. Expr: %s", + maxIndexValue, + celExpr); + } + private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) { // Tag the extension CelSource.Builder celSourceBuilder = diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 3668d776..a67bb525 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -18,6 +18,7 @@ import static dev.cel.common.CelOverloadDecl.newGlobalOverload; import static org.junit.Assert.assertThrows; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameter; @@ -481,6 +482,73 @@ public void blockIndex_invalidArgument_throws() { assertThat(e).hasMessageThat().contains("undeclared reference"); } + @Test + public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception { + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions("cel.block([1, 2], cel.block([2], 3))"); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e) + .hasMessageThat() + .isEqualTo("Expected 1 cel.block function to be present but found 2"); + } + + @Test + public void verifyOptimizedAstCorrectness_celBlockNotAtRoot_throws() throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions("1 + cel.block([1, 2], index0)"); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e).hasMessageThat().isEqualTo("Expected cel.block to be present at root"); + } + + @Test + public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([1, index0], 2)"); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e) + .hasMessageThat() + .isEqualTo("Expected at least one reference of index in cel.block result"); + } + + @Test + @TestParameters("{source: 'cel.block([], index0)'}") + @TestParameters("{source: 'cel.block([1, 2], index2)'}") + public void verifyOptimizedAstCorrectness_indexOutOfBounds_throws(String source) + throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e) + .hasMessageThat() + .contains("Illegal block index found. The index value must be less than"); + } + + @Test + @TestParameters("{source: 'cel.block([index0], index0)'}") + @TestParameters("{source: 'cel.block([1, index1, 2], index2)'}") + @TestParameters("{source: 'cel.block([1, 2, index2], index2)'}") + @TestParameters("{source: 'cel.block([index2, 1, 2], index2)'}") + public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(String source) + throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e) + .hasMessageThat() + .contains("Illegal block index found. The index value must be less than"); + } + /** * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block * -> cel.@block)