diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index aef7de69a8d..02b97baa939 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -6,6 +6,7 @@ package org.opensearch.sql.analysis; +import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST; import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; @@ -44,6 +45,7 @@ import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; @@ -64,7 +66,6 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprMissingValue; import org.opensearch.sql.data.type.ExprCoreType; -import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -84,6 +85,7 @@ import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalML; @@ -211,7 +213,6 @@ public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext contex tableFunctionImplementation.applyArguments()); } - @Override public LogicalPlan visitLimit(Limit node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); @@ -587,4 +588,9 @@ private SortOption analyzeSortOption(List fieldArgs) { return asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC; } + @Override + public LogicalPlan visitFetchCursor(FetchCursor cursor, AnalysisContext context) { + return new LogicalFetchCursor(cursor.getCursor(), + dataSourceService.getDataSource(DEFAULT_DATASOURCE_NAME).getStorageEngine()); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 9c283d95f6d..beb4833d4d2 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -43,6 +43,7 @@ import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; @@ -299,4 +300,8 @@ public T visitExplain(Explain node, C context) { public T visitPaginate(Paginate paginate, C context) { return visitChildren(paginate, context); } + + public T visitFetchCursor(FetchCursor cursor, C context) { + return visit(cursor, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java b/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java new file mode 100644 index 00000000000..aa327c295b0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * An unresolved plan that represents fetching the next + * batch in paginationed plan. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class FetchCursor extends UnresolvedPlan { + @Getter + final String cursor; + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFetchCursor(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + throw new UnsupportedOperationException("Cursor unresolved plan does not support children"); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index a4cd1982cd3..94e70819204 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -46,14 +46,6 @@ public void execute(UnresolvedPlan plan, } } - /** - * Execute a physical plan without analyzing or planning anything. - */ - public void executePlan(PhysicalPlan plan, - ResponseListener listener) { - executionEngine.execute(plan, ExecutionContext.emptyExecutionContext(), listener); - } - /** * Execute the {@link UnresolvedPlan}, with {@link PlanContext} and using {@link ResponseListener} * to get response. diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java deleted file mode 100644 index eda65aba2da..00000000000 --- a/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.executor.execution; - -import org.opensearch.sql.common.response.ResponseListener; -import org.opensearch.sql.executor.ExecutionEngine; -import org.opensearch.sql.executor.QueryId; -import org.opensearch.sql.executor.QueryService; -import org.opensearch.sql.executor.pagination.PlanSerializer; -import org.opensearch.sql.planner.physical.PhysicalPlan; - -/** - * ContinuePaginatedPlan represents cursor a request. - * It returns subsequent pages to the user (2nd page and all next). - */ -public class ContinuePaginatedPlan extends AbstractPlan { - - private final String cursor; - private final QueryService queryService; - private final PlanSerializer planSerializer; - - private final ResponseListener queryResponseListener; - - - /** - * Create an abstract plan that can continue paginating a given cursor. - */ - public ContinuePaginatedPlan(QueryId queryId, String cursor, QueryService queryService, - PlanSerializer planCache, - ResponseListener - queryResponseListener) { - super(queryId); - this.cursor = cursor; - this.planSerializer = planCache; - this.queryService = queryService; - this.queryResponseListener = queryResponseListener; - } - - @Override - public void execute() { - try { - PhysicalPlan plan = planSerializer.convertToPlan(cursor); - queryService.executePlan(plan, queryResponseListener); - } catch (Exception e) { - queryResponseListener.onFailure(e); - } - } - - @Override - public void explain(ResponseListener listener) { - listener.onFailure(new UnsupportedOperationException( - "Explain of a paged query continuation is not supported. " - + "Use `explain` for the initial query request.")); - } -} diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java index 18455c2a021..cc53f5060b2 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java @@ -17,12 +17,14 @@ import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryService; -import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; /** * QueryExecution Factory. @@ -39,7 +41,6 @@ public class QueryPlanFactory * Query Service. */ private final QueryService queryService; - private final PlanSerializer planSerializer; /** * NO_CONSUMER_RESPONSE_LISTENER should never be called. It is only used as constructor @@ -65,7 +66,7 @@ public void onFailure(Exception e) { /** * Create QueryExecution from Statement. */ - public AbstractPlan createContinuePaginatedPlan( + public AbstractPlan create( Statement statement, Optional> queryListener, Optional> explainListener) { @@ -73,17 +74,20 @@ public AbstractPlan createContinuePaginatedPlan( } /** - * Creates a ContinuePaginatedPlan from a cursor. + * Creates a QueryPlan from a cursor. */ - public AbstractPlan createContinuePaginatedPlan(String cursor, boolean isExplain, - ResponseListener queryResponseListener, - ResponseListener explainListener) { + public AbstractPlan create(String cursor, boolean isExplain, + ResponseListener queryResponseListener, + ResponseListener explainListener) { QueryId queryId = QueryId.queryId(); - var plan = new ContinuePaginatedPlan(queryId, cursor, queryService, - planSerializer, queryResponseListener); + var plan = new QueryPlan(queryId, new FetchCursor(cursor), queryService, queryResponseListener); return isExplain ? new ExplainPlan(queryId, plan, explainListener) : plan; } + boolean canConvertToCursor(UnresolvedPlan plan) { + return plan.accept(new CanPaginateVisitor(), null); + } + @Override public AbstractPlan visitQuery( Query node, @@ -94,7 +98,7 @@ public AbstractPlan visitQuery( context.getLeft().isPresent(), "[BUG] query listener must be not null"); if (node.getFetchSize() > 0) { - if (planSerializer.canConvertToCursor(node.getPlan())) { + if (canConvertToCursor(node.getPlan())) { return new QueryPlan(QueryId.queryId(), node.getPlan(), node.getFetchSize(), queryService, context.getLeft().get()); @@ -119,7 +123,7 @@ public AbstractPlan visitExplain( return new ExplainPlan( QueryId.queryId(), - createContinuePaginatedPlan(node.getStatement(), + create(node.getStatement(), Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), context.getRight().get()); } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java index d6d10ee89cf..07cf174d73f 100644 --- a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java @@ -18,7 +18,6 @@ import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import lombok.RequiredArgsConstructor; -import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.exception.NoCursorException; import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -34,9 +33,6 @@ public class PlanSerializer { private final StorageEngine engine; - public boolean canConvertToCursor(UnresolvedPlan plan) { - return plan.accept(new CanPaginateVisitor(), null); - } /** * Converts a physical plan tree to a cursor. diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index 9bde4ab6474..a1897245ea7 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -6,9 +6,11 @@ package org.opensearch.sql.planner; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; @@ -148,6 +150,11 @@ public PhysicalPlan visitRelation(LogicalRelation node, C context) { + "implementing and optimizing logical plan with relation involved"); } + @Override + public PhysicalPlan visitFetchCursor(LogicalFetchCursor plan, C context) { + return new PlanSerializer(plan.getEngine()).convertToPlan(plan.getCursor()); + } + protected PhysicalPlan visitChild(LogicalPlan node, C context) { // Logical operators visited here must have a single child return node.getChild().get(0).accept(this, context); diff --git a/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java index 487b1da6bde..ab195da5bfb 100644 --- a/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java +++ b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java @@ -6,10 +6,6 @@ package org.opensearch.sql.planner; import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import org.opensearch.sql.executor.pagination.PlanSerializer; /** * All subtypes of PhysicalPlan which needs to be serialized (in cursor, for pagination feature) @@ -29,21 +25,6 @@ */ public interface SerializablePlan extends Externalizable { - /** - * Argument is an instance of {@link PlanSerializer.CursorDeserializationStream}. - */ - @Override - void readExternal(ObjectInput in) throws IOException, ClassNotFoundException; - - /** - * Each plan which has as a child plan should do. - *
{@code
-   * out.writeObject(input.getPlanForSerialization());
-   * }
- */ - @Override - void writeExternal(ObjectOutput out) throws IOException; - /** * Override to return child or delegated plan, so parent plan should skip this one * for serialization, but it should try to serialize grandchild plan. @@ -55,6 +36,10 @@ public interface SerializablePlan extends Externalizable { * * In that case only plans A and C should be attempted to serialize. * It is needed to skip a `ResourceMonitorPlan` instance only, actually. + * + *
{@code
+   *    * A.writeObject(B.getPlanForSerialization());
+   *  }
* @return Next plan for serialization. */ default SerializablePlan getPlanForSerialization() { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java new file mode 100644 index 00000000000..d9a426dfe7e --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.storage.StorageEngine; + +@EqualsAndHashCode(callSuper = false) +@ToString +public class LogicalFetchCursor extends LogicalPlan { + @Getter + private final String cursor; + + @Getter + private final StorageEngine engine; + + /** + * LogicalCursor constructor. Does not have child plans. + */ + public LogicalFetchCursor(String cursor, StorageEngine engine) { + super(List.of()); + this.cursor = cursor; + this.engine = engine; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitFetchCursor(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index e95e47a013c..c0e253ca50e 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -11,19 +11,18 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; -import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; /** @@ -32,6 +31,10 @@ @UtilityClass public class LogicalPlanDSL { + public static LogicalPlan fetchCursor(String cursor, StorageEngine engine) { + return new LogicalFetchCursor(cursor, engine); + } + public static LogicalPlan write(LogicalPlan input, Table table, List columns) { return new LogicalWrite(input, table, columns); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index b3d63e843f7..796fb50f260 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -108,4 +108,8 @@ public R visitAD(LogicalAD plan, C context) { public R visitPaginate(LogicalPaginate plan, C context) { return visitNode(plan, context); } + + public R visitFetchCursor(LogicalFetchCursor plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index afe86d0cb1c..be1227c1dae 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -13,7 +13,6 @@ import java.util.List; import java.util.stream.Collectors; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.rule.CreatePagingTableScanBuilder; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; @@ -52,11 +51,11 @@ public static LogicalPlanOptimizer create() { * Phase 2: Transformations that rely on data source push down capability */ new CreateTableScanBuilder(), - new CreatePagingTableScanBuilder(), TableScanPushDown.PUSH_DOWN_FILTER, TableScanPushDown.PUSH_DOWN_AGGREGATION, TableScanPushDown.PUSH_DOWN_SORT, TableScanPushDown.PUSH_DOWN_LIMIT, + new PushDownPageSize(), TableScanPushDown.PUSH_DOWN_HIGHLIGHT, TableScanPushDown.PUSH_DOWN_NESTED, TableScanPushDown.PUSH_DOWN_PROJECT, diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java new file mode 100644 index 00000000000..8150de824da --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Optional; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * A {@link LogicalPlanOptimizer} rule that pushes down page size + * to table scan builder. + */ +public class PushDownPageSize implements Rule { + @Override + public Pattern pattern() { + return Pattern.typeOf(LogicalPaginate.class) + .matching(lp -> findTableScanBuilder(lp).isPresent()); + } + + @Override + public LogicalPlan apply(LogicalPaginate plan, Captures captures) { + + var builder = findTableScanBuilder(plan).orElseThrow(); + if (!builder.pushDownPageSize(plan)) { + throw new IllegalStateException("Failed to push down LogicalPaginate"); + } + return plan.getChild().get(0); + } + + private Optional findTableScanBuilder(LogicalPaginate logicalPaginate) { + Deque plans = new ArrayDeque<>(); + plans.add(logicalPaginate); + do { + var plan = plans.removeFirst(); + var children = plan.getChild(); + if (children.stream().anyMatch(TableScanBuilder.class::isInstance)) { + if (children.size() > 1) { + throw new UnsupportedOperationException( + "Unsupported plan: relation operator cannot have siblings"); + } + return Optional.of((TableScanBuilder) children.get(0)); + } + plans.addAll(children); + } while (!plans.isEmpty()); + return Optional.empty(); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java deleted file mode 100644 index c635400c333..00000000000 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.planner.optimizer.rule; - -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.List; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.planner.logical.LogicalPaginate; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Rule to create a paged TableScanBuilder in pagination request. - */ -public class CreatePagingTableScanBuilder implements Rule { - /** Capture the table inside matched logical paginate operator. */ - private LogicalPlan relationParent = null; - /** Pattern that matches logical relation operator. */ - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor. - */ - public CreatePagingTableScanBuilder() { - this.pattern = Pattern.typeOf(LogicalPaginate.class).matching(this::findLogicalRelation); - } - - /** - * Finds an instance of LogicalRelation and saves a reference in relationParent variable. - * @param logicalPaginate An instance of LogicalPaginate - * @return true if {@link LogicalRelation} node was found among the descendents of - * {@link this.logicalPaginate}, false otherwise. - */ - private boolean findLogicalRelation(LogicalPaginate logicalPaginate) { - Deque plans = new ArrayDeque<>(); - plans.add(logicalPaginate); - do { - final var plan = plans.removeFirst(); - final var children = plan.getChild(); - if (children.stream().anyMatch(LogicalRelation.class::isInstance)) { - if (children.size() > 1) { - throw new UnsupportedOperationException( - "Unsupported plan: relation operator cannot have siblings"); - } - relationParent = plan; - return true; - } - plans.addAll(children); - } while (!plans.isEmpty()); - return false; - } - - - @Override - public LogicalPlan apply(LogicalPaginate plan, Captures captures) { - var logicalRelation = (LogicalRelation) relationParent.getChild().get(0); - var scan = logicalRelation.getTable().createPagedScanBuilder(plan.getPageSize()); - relationParent.replaceChildPlans(List.of(scan)); - - return plan.getChild().get(0); - } -} diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java b/core/src/main/java/org/opensearch/sql/storage/Table.java index 0194f1d03e7..fc1def5a2e6 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -100,8 +100,4 @@ default StreamingSource asStreamingSource() { throw new UnsupportedOperationException(); } - default TableScanBuilder createPagedScanBuilder(int pageSize) { - var error = String.format("'%s' does not support pagination", getClass().toString()); - throw new UnsupportedOperationException(error); - } } diff --git a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java index 9af66e219fa..f0158c52b80 100644 --- a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java +++ b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java @@ -11,6 +11,7 @@ import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; import org.opensearch.sql.planner.logical.LogicalProject; @@ -28,7 +29,7 @@ public abstract class TableScanBuilder extends LogicalPlan { /** * Construct and initialize children to empty list. */ - public TableScanBuilder() { + protected TableScanBuilder() { super(Collections.emptyList()); } @@ -116,6 +117,10 @@ public boolean pushDownNested(LogicalNested nested) { return false; } + public boolean pushDownPageSize(LogicalPaginate paginate) { + return false; + } + @Override public R accept(LogicalPlanNodeVisitor visitor, C context) { return visitor.visitTableScanBuilder(this, context); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 20927f262c4..dda359a7dfe 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -75,6 +75,7 @@ import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Paginate; @@ -90,6 +91,7 @@ import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPaginate; @@ -1641,4 +1643,12 @@ public void visit_paginate() { assertTrue(actual instanceof LogicalPaginate); assertEquals(10, ((LogicalPaginate) actual).getPageSize()); } + + @Test + void visit_cursor() { + LogicalPlan actual = analyze((new FetchCursor("test"))); + assertTrue(actual instanceof LogicalFetchCursor); + assertEquals(new LogicalFetchCursor("test", + dataSourceService.getDataSource("@opensearch").getStorageEngine()), actual); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java deleted file mode 100644 index 3e08280acbe..00000000000 --- a/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.executor.execution; - -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.CALLS_REAL_METHODS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; - -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.opensearch.sql.common.response.ResponseListener; -import org.opensearch.sql.executor.DefaultExecutionEngine; -import org.opensearch.sql.executor.ExecutionEngine; -import org.opensearch.sql.executor.QueryId; -import org.opensearch.sql.executor.QueryService; -import org.opensearch.sql.executor.pagination.PlanSerializer; -import org.opensearch.sql.planner.physical.PhysicalPlan; -import org.opensearch.sql.storage.StorageEngine; - -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class ContinuePaginatedPlanTest { - - private static PlanSerializer planSerializer; - - private static QueryService queryService; - - /** - * Initialize the mocks. - */ - @BeforeAll - public static void setUp() { - var storageEngine = mock(StorageEngine.class); - planSerializer = new PlanSerializer(storageEngine); - queryService = new QueryService(null, new DefaultExecutionEngine(), null); - } - - @Test - public void can_execute_plan() { - var planSerializer = mock(PlanSerializer.class); - when(planSerializer.convertToPlan(anyString())).thenReturn(mock(PhysicalPlan.class)); - var listener = new ResponseListener() { - @Override - public void onResponse(ExecutionEngine.QueryResponse response) { - assertNotNull(response); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }; - var plan = new ContinuePaginatedPlan(QueryId.queryId(), "", - queryService, planSerializer, listener); - plan.execute(); - } - - @Test - public void can_handle_error_while_executing_plan() { - var listener = new ResponseListener() { - @Override - public void onResponse(ExecutionEngine.QueryResponse response) { - fail(); - } - - @Override - public void onFailure(Exception e) { - assertNotNull(e); - } - }; - var plan = new ContinuePaginatedPlan(QueryId.queryId(), "", queryService, - planSerializer, listener); - plan.execute(); - } - - @Test - public void explain_is_not_supported() { - var listener = mock(ResponseListener.class); - mock(ContinuePaginatedPlan.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)) - .explain(listener); - verify(listener).onFailure(any(UnsupportedOperationException.class)); - } -} diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java index 6bdbf1c4c9d..c35d506fe7f 100644 --- a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java @@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.sql.executor.execution.QueryPlanFactory.NO_CONSUMER_RESPONSE_LISTENER; @@ -29,7 +30,7 @@ import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryService; -import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; @ExtendWith(MockitoExtension.class) class QueryPlanFactoryTest { @@ -49,20 +50,18 @@ class QueryPlanFactoryTest { @Mock private ExecutionEngine.QueryResponse queryResponse; - @Mock - private PlanSerializer planSerializer; private QueryPlanFactory factory; @BeforeEach void init() { - factory = new QueryPlanFactory(queryService, planSerializer); + factory = new QueryPlanFactory(queryService); } @Test public void createFromQueryShouldSuccess() { Statement query = new Query(plan, 0); AbstractPlan queryExecution = - factory.createContinuePaginatedPlan(query, Optional.of(queryListener), Optional.empty()); + factory.create(query, Optional.of(queryListener), Optional.empty()); assertTrue(queryExecution instanceof QueryPlan); } @@ -70,18 +69,18 @@ public void createFromQueryShouldSuccess() { public void createFromExplainShouldSuccess() { Statement query = new Explain(new Query(plan, 0)); AbstractPlan queryExecution = - factory.createContinuePaginatedPlan(query, Optional.empty(), Optional.of(explainListener)); + factory.create(query, Optional.empty(), Optional.of(explainListener)); assertTrue(queryExecution instanceof ExplainPlan); } @Test public void createFromCursorShouldSuccess() { - AbstractPlan queryExecution = factory.createContinuePaginatedPlan("", false, + AbstractPlan queryExecution = factory.create("", false, queryListener, explainListener); - AbstractPlan explainExecution = factory.createContinuePaginatedPlan("", true, + AbstractPlan explainExecution = factory.create("", true, queryListener, explainListener); assertAll( - () -> assertTrue(queryExecution instanceof ContinuePaginatedPlan), + () -> assertTrue(queryExecution instanceof QueryPlan), () -> assertTrue(explainExecution instanceof ExplainPlan) ); } @@ -91,7 +90,7 @@ public void createFromQueryWithoutQueryListenerShouldThrowException() { Statement query = new Query(plan, 0); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.createContinuePaginatedPlan( + assertThrows(IllegalArgumentException.class, () -> factory.create( query, Optional.empty(), Optional.empty())); assertEquals("[BUG] query listener must be not null", exception.getMessage()); } @@ -101,7 +100,7 @@ public void createFromExplainWithoutExplainListenerShouldThrowException() { Statement query = new Explain(new Query(plan, 0)); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.createContinuePaginatedPlan( + assertThrows(IllegalArgumentException.class, () -> factory.create( query, Optional.empty(), Optional.empty())); assertEquals("[BUG] explain listener must be not null", exception.getMessage()); } @@ -125,21 +124,21 @@ public void noConsumerResponseChannel() { @Test public void createQueryWithFetchSizeWhichCanBePaged() { - when(planSerializer.canConvertToCursor(plan)).thenReturn(true); - factory = new QueryPlanFactory(queryService, planSerializer); + when(plan.accept(any(CanPaginateVisitor.class), any())).thenReturn(Boolean.TRUE); + factory = new QueryPlanFactory(queryService); Statement query = new Query(plan, 10); AbstractPlan queryExecution = - factory.createContinuePaginatedPlan(query, Optional.of(queryListener), Optional.empty()); + factory.create(query, Optional.of(queryListener), Optional.empty()); assertTrue(queryExecution instanceof QueryPlan); } @Test public void createQueryWithFetchSizeWhichCannotBePaged() { - when(planSerializer.canConvertToCursor(plan)).thenReturn(false); - factory = new QueryPlanFactory(queryService, planSerializer); + when(plan.accept(any(CanPaginateVisitor.class), any())).thenReturn(Boolean.FALSE); + factory = new QueryPlanFactory(queryService); Statement query = new Query(plan, 10); assertThrows(UnsupportedCursorRequestException.class, - () -> factory.createContinuePaginatedPlan(query, + () -> factory.create(query, Optional.of(queryListener), Optional.empty())); } } diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java index b1e97920c89..8211a3bc12f 100644 --- a/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java @@ -15,27 +15,20 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.io.Serializable; -import java.util.List; import lombok.SneakyThrows; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.opensearch.sql.ast.dsl.AstDSL; -import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.exception.NoCursorException; import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; -import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.utils.TestOperator; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class PlanSerializerTest { @@ -50,23 +43,6 @@ void setUp() { planCache = new PlanSerializer(storageEngine); } - @Test - void canConvertToCursor_relation() { - assertTrue(planCache.canConvertToCursor(AstDSL.relation("Table"))); - } - - @Test - void canConvertToCursor_project_allFields_relation() { - var unresolvedPlan = AstDSL.project(AstDSL.relation("table"), AstDSL.allFields()); - assertTrue(planCache.canConvertToCursor(unresolvedPlan)); - } - - @Test - void canConvertToCursor_project_some_fields_relation() { - var unresolvedPlan = AstDSL.project(AstDSL.relation("table"), AstDSL.field("rando")); - Assertions.assertFalse(planCache.canConvertToCursor(unresolvedPlan)); - } - @ParameterizedTest @ValueSource(strings = {"pewpew", "asdkfhashdfjkgakgfwuigfaijkb", "ajdhfgajklghadfjkhgjkadhgad" + "kadfhgadhjgfjklahdgqheygvskjfbvgsdklgfuirehiluANUIfgauighbahfuasdlhfnhaughsdlfhaughaggf" @@ -112,7 +88,7 @@ void serialize_deserialize_obj() { void serialize_throws() { assertThrows(Throwable.class, () -> serialize(new NotSerializableTestClass())); var testObj = new TestOperator(); - testObj.throwIoOnWrite = true; + testObj.setThrowIoOnWrite(true); assertThrows(Throwable.class, () -> serialize(testObj)); } @@ -130,7 +106,7 @@ void deserialize_throws() { @SneakyThrows void convertToCursor_returns_no_cursor_if_cant_serialize() { var plan = new TestOperator(42); - plan.throwNoCursorOnWrite = true; + plan.setThrowNoCursorOnWrite(true); assertAll( () -> assertThrows(NoCursorException.class, () -> serialize(plan)), () -> assertEquals(Cursor.None, planCache.convertToCursor(plan)) @@ -191,60 +167,6 @@ void resolveObject() { // Helpers and auxiliary classes section below - public static class TestOperator extends PhysicalPlan implements SerializablePlan { - private int field; - private boolean throwNoCursorOnWrite = false; - private boolean throwIoOnWrite = false; - - public TestOperator() { - } - - public TestOperator(int value) { - field = value; - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - field = in.readInt(); - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - if (throwNoCursorOnWrite) { - throw new NoCursorException(); - } - if (throwIoOnWrite) { - throw new IOException(); - } - out.writeInt(field); - } - - @Override - public boolean equals(Object o) { - return field == ((TestOperator) o).field; - } - - @Override - public R accept(PhysicalPlanNodeVisitor visitor, C context) { - return null; - } - - @Override - public boolean hasNext() { - return false; - } - - @Override - public ExprValue next() { - return null; - } - - @Override - public List getChild() { - return null; - } - } - @SneakyThrows private String serialize(Serializable input) { return new PlanSerializer(null).serialize(input); diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index bf1464f5f67..d43cb89a3ed 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -46,6 +46,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; @@ -54,17 +55,18 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.expression.window.ranking.RowNumberFunction; -import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; import org.opensearch.sql.storage.write.TableWriteOperator; +import org.opensearch.sql.utils.TestOperator; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -222,6 +224,16 @@ public void visitWindowOperator_should_return_PhysicalWindowOperator() { assertEquals(physicalPlan, logicalPlan.accept(implementor, null)); } + @Test + void visitLogicalCursor_deserializes_it() { + var engine = Mockito.mock(StorageEngine.class); + + var physicalPlan = new TestOperator(); + var logicalPlan = LogicalPlanDSL.fetchCursor(new PlanSerializer(engine) + .convertToCursor(physicalPlan).toString(), engine); + assertEquals(physicalPlan, logicalPlan.accept(implementor, null)); + } + @Test public void visitTableScanBuilder_should_build_TableScanOperator() { TableScanOperator tableScanOperator = Mockito.mock(TableScanOperator.class); diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 34e0e39d872..e826a13f6c4 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -37,6 +37,7 @@ import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -130,9 +131,11 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalNested nested = new LogicalNested(null, nestedArgs, projectList); + LogicalFetchCursor cursor = new LogicalFetchCursor("n:test", mock(StorageEngine.class)); return Stream.of( relation, tableScanBuilder, write, tableWriteBuilder, filter, aggregation, rename, project, - remove, eval, sort, dedup, window, rareTopN, highlight, mlCommons, ad, ml, paginate, nested + remove, eval, sort, dedup, window, rareTopN, highlight, mlCommons, ad, ml, paginate, nested, + cursor ).map(Arguments::of); } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 543b261d9ef..faedb881113 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -9,9 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; @@ -52,9 +50,8 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.rule.CreatePagingTableScanBuilder; -import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -70,13 +67,9 @@ class LogicalPlanOptimizerTest { @Spy private TableScanBuilder tableScanBuilder; - @Spy - private TableScanBuilder pagedTableScanBuilder; - @BeforeEach void setUp() { lenient().when(table.createScanBuilder()).thenReturn(tableScanBuilder); - lenient().when(table.createPagedScanBuilder(anyInt())).thenReturn(pagedTableScanBuilder); } /** @@ -344,45 +337,50 @@ public PhysicalPlan implement(LogicalPlan plan) { @Test void paged_table_scan_builder_support_project_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(true); var relation = relation("schema", table); + var optimized = LogicalPlanOptimizer.create() + .optimize(paginate(project(relation), 4)); + verify(tableScanBuilder).pushDownPageSize(any()); - assertEquals( - project(pagedTableScanBuilder), - LogicalPlanOptimizer.create().optimize(paginate(project(relation), 4))); + assertEquals(project(tableScanBuilder), optimized); } - @Test - void push_page_size_noop_if_no_relation() { - var paginate = new LogicalPaginate(42, List.of(project(values()))); - assertEquals(paginate, LogicalPlanOptimizer.create().optimize(paginate)); + void push_down_page_size_multiple_children() { + var relation = relation("schema", table); + var twoChildrenPlan = new LogicalPlan(List.of(relation, relation)) { + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return null; + } + }; + var queryPlan = paginate(twoChildrenPlan, 4); + var optimizer = LogicalPlanOptimizer.create(); + final var exception = assertThrows(UnsupportedOperationException.class, + () -> optimizer.optimize(queryPlan)); + assertEquals("Unsupported plan: relation operator cannot have siblings", + exception.getMessage()); } @Test - void pagination_optimizer_simple_query() { - var projectPlan = project(relation("schema", table), DSL.named(DSL.ref("intV", INTEGER))); - - var optimizer = new LogicalPlanOptimizer( - List.of(new CreateTableScanBuilder(), new CreatePagingTableScanBuilder())); + void push_down_page_size_push_failed() { + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(false); - { - optimizer.optimize(projectPlan); - verify(table).createScanBuilder(); - verify(table, never()).createPagedScanBuilder(anyInt()); - } + var queryPlan = paginate( + project( + relation("schema", table)), 4); + var optimizer = LogicalPlanOptimizer.create(); + final var exception = assertThrows(IllegalStateException.class, + () -> optimizer.optimize(queryPlan)); + assertEquals("Failed to push down LogicalPaginate", exception.getMessage()); } @Test - void pagination_optimizer_paged_query() { - var relation = new LogicalRelation("schema", table); - var projectPlan = project(relation, DSL.named(DSL.ref("intV", INTEGER))); - var pagedPlan = new LogicalPaginate(10, List.of(projectPlan)); - - var optimizer = new LogicalPlanOptimizer( - List.of(new CreateTableScanBuilder(), new CreatePagingTableScanBuilder())); - var optimized = optimizer.optimize(pagedPlan); - verify(table).createPagedScanBuilder(anyInt()); + void push_page_size_noop_if_no_relation() { + var paginate = new LogicalPaginate(42, List.of(project(values()))); + assertEquals(paginate, LogicalPlanOptimizer.create().optimize(paginate)); } @Test @@ -394,19 +392,18 @@ void push_page_size_noop_if_no_sub_plans() { @Test void table_scan_builder_support_offset_push_down_can_apply_its_rule() { - when(table.createPagedScanBuilder(anyInt())).thenReturn(pagedTableScanBuilder); + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(true); var relation = new LogicalRelation("schema", table); var optimized = LogicalPlanOptimizer.create() .optimize(new LogicalPaginate(42, List.of(project(relation)))); - // `optimized` structure: LogicalPaginate -> LogicalProject -> TableScanBuilder + // `optimized` structure: LogicalProject -> TableScanBuilder // LogicalRelation replaced by a TableScanBuilder instance - assertEquals(project(pagedTableScanBuilder), optimized); + assertEquals(project(tableScanBuilder), optimized); } private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); - final LogicalPlan optimize = optimizer.optimize(plan); - return optimize; + return optimizer.optimize(plan); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java deleted file mode 100644 index 79c7b55c60b..00000000000 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.planner.optimizer.rule; - -import static com.facebook.presto.matching.DefaultMatcher.DEFAULT_MATCHER; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.when; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; - -import java.util.List; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.storage.Table; - -@ExtendWith(MockitoExtension.class) -class CreatePagingTableScanBuilderTest { - - @Mock - LogicalPlan multiRelationPaginate; - - @Mock - Table table; - - @BeforeEach - public void setUp() { - when(multiRelationPaginate.getChild()) - .thenReturn( - List.of(relation("t1", table), relation("t2", table))); - } - - @Test - void throws_when_mutliple_children() { - final var pattern = new CreatePagingTableScanBuilder().pattern(); - final var plan = paginate(multiRelationPaginate, 42); - assertThrows(UnsupportedOperationException.class, - () -> DEFAULT_MATCHER.match(pattern, plan)); - } -} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java index 77fcb7a5054..f5ecf76bd09 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java @@ -23,13 +23,9 @@ import com.google.common.collect.ImmutableMap; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInput; import java.io.ObjectInputStream; -import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.util.List; -import lombok.EqualsAndHashCode; import lombok.SneakyThrows; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -40,7 +36,7 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.utils.TestOperator; @ExtendWith(MockitoExtension.class) class ProjectOperatorTest extends PhysicalPlanTestBase { @@ -234,36 +230,4 @@ public void serializable() { var roundTripPlan = (ProjectOperator) objectInput.readObject(); assertEquals(project, roundTripPlan); } - - @EqualsAndHashCode(callSuper = false) - public static class TestOperator extends PhysicalPlan implements SerializablePlan { - - @Override - public R accept(PhysicalPlanNodeVisitor visitor, C context) { - return null; - } - - @Override - public boolean hasNext() { - return false; - } - - @Override - public ExprValue next() { - return null; - } - - @Override - public List getChild() { - return null; - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - } - } } diff --git a/core/src/test/java/org/opensearch/sql/storage/TableTest.java b/core/src/test/java/org/opensearch/sql/storage/TableTest.java deleted file mode 100644 index a96ee71af0b..00000000000 --- a/core/src/test/java/org/opensearch/sql/storage/TableTest.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.storage; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.withSettings; - -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.mockito.invocation.InvocationOnMock; - -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class TableTest { - - @Test - public void createPagedScanBuilder_throws() { - var table = mock(Table.class, withSettings().defaultAnswer(InvocationOnMock::callRealMethod)); - assertThrows(Throwable.class, () -> table.createPagedScanBuilder(4)); - } -} diff --git a/core/src/test/java/org/opensearch/sql/utils/TestOperator.java b/core/src/test/java/org/opensearch/sql/utils/TestOperator.java new file mode 100644 index 00000000000..584cf6f3fd0 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/utils/TestOperator.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.utils; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.List; +import lombok.Setter; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +public class TestOperator extends PhysicalPlan implements SerializablePlan { + private int field; + @Setter + private boolean throwNoCursorOnWrite = false; + @Setter + private boolean throwIoOnWrite = false; + + public TestOperator() { + } + + public TestOperator(int value) { + field = value; + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + field = in.readInt(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (throwNoCursorOnWrite) { + throw new NoCursorException(); + } + if (throwIoOnWrite) { + throw new IOException(); + } + out.writeInt(field); + } + + @Override + public boolean equals(Object o) { + return field == ((TestOperator) o).field; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return null; + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public List getChild() { + return null; + } +} diff --git a/docs/dev/Pagination-v2.md b/docs/dev/Pagination-v2.md index 2416ae51de2..5379153a4a9 100644 --- a/docs/dev/Pagination-v2.md +++ b/docs/dev/Pagination-v2.md @@ -202,7 +202,7 @@ classDiagram ``` When `QueryPlanFactory.create` is passed a subsequent query request, it: -1. Creates an instance of `Cursor` unresolved plan as the sole node in the unresolved query plan. +1. Creates an instance of `FetchCursor` unresolved plan as the sole node in the unresolved query plan. ```mermaid classDiagram @@ -213,11 +213,11 @@ classDiagram -UnresolvedPlan plan -QueryService queryService } - class Cursor { + class FetchCursor { <> -String cursorId } - QueryPlan --* Cursor + QueryPlan --* FetchCursor ``` The examples below show Abstract Query Plan for the same query in different request types: @@ -256,7 +256,7 @@ stateDiagram-v2 } state "Subsequent Query Request" As Sub { - Cursor + FetchCursor } ``` @@ -284,7 +284,7 @@ classDiagram LogicalQueryPlan --* LogicalRelation ``` -For subsequent page requests, `Cursor` unresolved plan is mapped to `LogicalCursor` logical plan. +For subsequent page requests, `FetchCursor` unresolved plan is mapped to `LogicalFetchCursor` logical plan. ```mermaid classDiagram @@ -292,11 +292,11 @@ classDiagram class LogicalQueryPlan { <> } - class LogicalCursor { + class LogicalFetchCursor { <> -String cursorId } - LogicalQueryPlan --* LogicalCursor + LogicalQueryPlan --* LogicalFetchCursor ``` The examples below show logical query plan for the same query in different request types: @@ -331,7 +331,7 @@ stateDiagram-v2 } state "Subsequent Query Request" As Sub { -Cursor +FetchCursor } ``` @@ -500,31 +500,20 @@ Subsequent pages are processed by a new workflow. The key point there: ```mermaid sequenceDiagram - participant SQLService - participant QueryPlanFactory - participant QueryService - participant OpenSearchExecutionEngine - participant PlanSerializer SQLService ->>+ QueryPlanFactory : execute QueryPlanFactory ->>+ QueryService : execute - rect rgb(91, 123, 155) - note over QueryService, PlanSerializer : Deserialization - QueryService ->>+ PlanSerializer: convertToPlan - PlanSerializer -->>- QueryService: Physical Query Plan - end - Note over QueryService : Planner, Optimizer and Implementor
are skipped - QueryService ->>+ OpenSearchExecutionEngine : execute - rect rgb(91, 123, 155) - note over OpenSearchExecutionEngine, PlanSerializer : Serialization - OpenSearchExecutionEngine ->>+ PlanSerializer : convertToCursor - PlanSerializer -->>- OpenSearchExecutionEngine : cursor - end - rect rgb(91, 123, 155) - Note over OpenSearchExecutionEngine : get total hits - end - OpenSearchExecutionEngine -->>- QueryService: execution completed - QueryService -->>- QueryPlanFactory : execution completed + QueryService ->>+ Analyzer : analyze + Analyzer -->>- QueryService : new LogicalFetchCursor + QueryService ->>+ Planner : plan + Planner ->>+ DefaultImplementor : implement + DefaultImplementor ->>+ PlanSerializer : deserialize + PlanSerializer -->>- DefaultImplementor: physical query plan + DefaultImplementor -->>- Planner : physical query plan + Planner -->>- QueryService : physical query plan + QueryService ->>+ OpenSearchExecutionEngine : execute + OpenSearchExecutionEngine -->>- QueryService: execution completed + QueryService -->>- QueryPlanFactory : execution completed QueryPlanFactory -->>- SQLService : execution completed ``` @@ -614,7 +603,6 @@ sequenceDiagram participant ResourceMonitorPlan participant OpenSearchIndexScan participant OpenSearchScrollRequest - participant OpenSearchScrollRequest PlanSerializer ->>+ ProjectOperator : getPlanForSerialization ProjectOperator -->>- PlanSerializer : this diff --git a/docs/dev/query-optimizer-improvement.md b/docs/dev/query-optimizer-improvement.md index 75030875621..720649b280c 100644 --- a/docs/dev/query-optimizer-improvement.md +++ b/docs/dev/query-optimizer-improvement.md @@ -91,6 +91,7 @@ classDiagram +pushDownAggregation(LogicalAggregation) boolean +pushDownSort(LogicalSort) boolean +pushDownLimit(LogicalLimit) boolean + +pushDownPageSize(LogicalPaginate) boolean +pushDownProject(LogicalProject) boolean +pushDownHighlight(LogicalHighlight) boolean +pushDownNested(LogicalNested) boolean @@ -103,16 +104,13 @@ classDiagram +pushDownAggregation(LogicalAggregation) boolean +pushDownSort(LogicalSort) boolean +pushDownLimit(LogicalLimit) boolean + +pushDownPageSize(LogicalPaginate) boolean +pushDownProject(LogicalProject) boolean +pushDownHighlight(LogicalHighlight) boolean +pushDownNested(LogicalNested) boolean +findReferenceExpression(NamedExpression)$ List~ReferenceExpression~ +findReferenceExpressions(List~NamedExpression~)$ Set~ReferenceExpression~ } - class OpenSearchPagedIndexScanBuilder { - +OpenSearchPagedIndexScanBuilder(OpenSearchPagedIndexScan) - +build() TableScanOperator - } class OpenSearchIndexScanBuilder { -TableScanBuilder delegate -boolean isLimitPushedDown @@ -131,7 +129,6 @@ classDiagram LogicalPlan <|-- TableScanBuilder TableScanBuilder <|-- OpenSearchIndexScanQueryBuilder - TableScanBuilder <|-- OpenSearchPagedIndexScanBuilder TableScanBuilder <|-- OpenSearchIndexScanBuilder OpenSearchIndexScanBuilder *-- "1" TableScanBuilder : delegate OpenSearchIndexScanBuilder <.. OpenSearchIndexScanQueryBuilder : creates @@ -159,7 +156,6 @@ classDiagram } class Table { +TableScanBuilder createScanBuilder() - +TableScanBuilder createPagedScanBuilder(int) } class TableScanPushDown~T~ { +Rule~T~ PUSH_DOWN_FILTER$ diff --git a/docs/user/optimization/optimization.rst b/docs/user/optimization/optimization.rst index e0fe9435609..8ab998309d6 100644 --- a/docs/user/optimization/optimization.rst +++ b/docs/user/optimization/optimization.rst @@ -287,7 +287,7 @@ The Aggregation operator will merge into OpenSearch Aggregation:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } @@ -313,7 +313,7 @@ The Sort operator will merge into OpenSearch Aggregation.:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"last\",\"order\":\"desc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"last\",\"order\":\"desc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } @@ -348,7 +348,7 @@ Because the OpenSearch Composite Aggregation doesn't support order by metrics fi { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } diff --git a/docs/user/ppl/interfaces/endpoint.rst b/docs/user/ppl/interfaces/endpoint.rst index fb64eff6883..793b94eb8d0 100644 --- a/docs/user/ppl/interfaces/endpoint.rst +++ b/docs/user/ppl/interfaces/endpoint.rst @@ -91,7 +91,7 @@ The following PPL query demonstrated that where and stats command were pushed do { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" }, "children": [] } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index 595fd8acd5f..b1fcbf7d1b1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -229,19 +229,18 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl } @Provides - public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { + public PlanSerializer planSerializer(StorageEngine storageEngine) { return new PlanSerializer(storageEngine); } @Provides - public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine, - PlanSerializer planSerializer) { + public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); QueryService queryService = new QueryService(analyzer, executionEngine, planner); - return new QueryPlanFactory(queryService, planSerializer); + return new QueryPlanFactory(queryService); } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java index d8213b1fe44..2a34dabd790 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java @@ -6,14 +6,9 @@ package org.opensearch.sql.sql; -import static org.opensearch.sql.legacy.TestUtils.getResponseBody; -import static org.opensearch.sql.legacy.TestUtils.isIndexExist; -import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; - +import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; @@ -23,25 +18,30 @@ import org.junit.Test; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; -import org.opensearch.client.Request; import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; // This class has only one test case, because it is parametrized and takes significant time @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class PaginationBlackboxIT extends SQLIntegTestCase { - private final String index; + private final Index index; private final Integer pageSize; - public PaginationBlackboxIT(@Name("index") String index, + public PaginationBlackboxIT(@Name("index") Index index, @Name("pageSize") Integer pageSize) { this.index = index; this.pageSize = pageSize; } + @Override + protected void init() throws IOException { + loadIndex(index); + } + @ParametersFactory(argumentFormatting = "index = %1$s, page_size = %2$d") public static Iterable compareTwoDates() { - var indices = new PaginationBlackboxHelper().getIndices(); + var indices = List.of(Index.ACCOUNT, Index.BEER, Index.BANK); var pageSizes = List.of(5, 10, 100, 1000); var testData = new ArrayList(); for (var index : indices) { @@ -55,63 +55,47 @@ public static Iterable compareTwoDates() { @Test @SneakyThrows public void test_pagination_blackbox() { - var response = executeJdbcRequest(String.format("select * from %s", index)); + var response = executeJdbcRequest(String.format("select * from %s", index.getName())); var indexSize = response.getInt("total"); var rows = response.getJSONArray("datarows"); var schema = response.getJSONArray("schema"); - var testReportPrefix = String.format("index: %s, page size: %d || ", index, pageSize); + var testReportPrefix = String.format("index: %s, page size: %d || ", index.getName(), pageSize); var rowsPaged = new JSONArray(); var rowsReturned = 0; - response = new JSONObject(executeFetchQuery( - String.format("select * from %s", index), pageSize, "jdbc")); + var responseCounter = 1; this.logger.info(testReportPrefix + "first response"); - while (response.has("cursor")) { - assertEquals(indexSize, response.getInt("total")); + response = new JSONObject(executeFetchQuery( + String.format("select * from %s", index.getName()), pageSize, "jdbc")); + + var cursor = response.has("cursor")? response.getString("cursor") : ""; + do { + this.logger.info(testReportPrefix + + String.format("subsequent response %d/%d", responseCounter++, (indexSize / pageSize) + 1)); assertTrue("Paged response schema doesn't match to non-paged", schema.similar(response.getJSONArray("schema"))); - var cursor = response.getString("cursor"); - assertTrue(testReportPrefix + "Cursor returned from legacy engine", - cursor.startsWith("n:")); + rowsReturned += response.getInt("size"); var datarows = response.getJSONArray("datarows"); for (int i = 0; i < datarows.length(); i++) { rowsPaged.put(datarows.get(i)); } - response = executeCursorQuery(cursor); - this.logger.info(testReportPrefix - + String.format("subsequent response %d/%d", responseCounter++, (indexSize / pageSize) + 1)); - } + + if (response.has("cursor")) { + TestUtils.verifyIsV2Cursor(response); + cursor = response.getString("cursor"); + response = executeCursorQuery(cursor); + } else { + cursor = ""; + } + + } while(!cursor.isEmpty()); assertTrue("Paged response schema doesn't match to non-paged", schema.similar(response.getJSONArray("schema"))); - assertEquals(0, response.getInt("total")); - assertEquals(testReportPrefix + "Last page is not empty", - 0, response.getInt("size")); - assertEquals(testReportPrefix + "Last page is not empty", - 0, response.getJSONArray("datarows").length()); assertEquals(testReportPrefix + "Paged responses return another row count that non-paged", indexSize, rowsReturned); assertTrue(testReportPrefix + "Paged accumulated result has other rows than non-paged", rows.similar(rowsPaged)); } - - // A dummy class created, because accessing to `client()` isn't available from a static context, - // but it is needed before an instance of `PaginationBlackboxIT` is created. - private static class PaginationBlackboxHelper extends SQLIntegTestCase { - - @SneakyThrows - private List getIndices() { - initClient(); - loadIndex(Index.ACCOUNT); - loadIndex(Index.BEER); - loadIndex(Index.BANK); - if (!isIndexExist(client(), "empty")) { - executeRequest(new Request("PUT", "/empty")); - } - return Arrays.stream(getResponseBody(client().performRequest(new Request("GET", "_cat/indices?h=i")), true).split("\n")) - // exclude this index, because it is too big and extends test time too long (almost 10k docs) - .map(String::trim).filter(i -> !i.equals(TEST_INDEX_ONLINE)).collect(Collectors.toList()); - } - } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java index 724451ef658..be208cd1374 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java @@ -21,7 +21,7 @@ public void init() throws IOException { } @After - void resetParams() throws IOException { + public void resetParams() throws IOException { resetMaxResultWindow(TEST_INDEX_PHRASE); resetQuerySizeLimit(); } @@ -31,14 +31,14 @@ public void testFetchSizeLessThanMaxResultWindow() throws IOException { setMaxResultWindow(TEST_INDEX_PHRASE, 6); JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); - String cursor = ""; int numRows = 0; do { // Process response - cursor = response.getString("cursor"); + String cursor = response.getString("cursor"); numRows += response.getJSONArray("datarows").length(); response = executeCursorQuery(cursor); } while (response.has("cursor")); + numRows += response.getJSONArray("datarows").length(); var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) .getJSONArray("datarows") @@ -54,15 +54,14 @@ public void testQuerySizeLimitDoesNotEffectTotalRowsReturned() throws IOExceptio JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); assertTrue(response.getInt("size") > querySizeLimit); - String cursor = ""; int numRows = 0; do { // Process response - cursor = response.getString("cursor"); + String cursor = response.getString("cursor"); numRows += response.getJSONArray("datarows").length(); response = executeCursorQuery(cursor); } while (response.has("cursor")); - + numRows += response.getJSONArray("datarows").length(); var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) .getJSONArray("datarows") .getJSONArray(0) diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java index 0095bec7cac..aad39c40744 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java @@ -26,6 +26,7 @@ import org.opensearch.common.inject.Injector; import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; @@ -98,7 +99,7 @@ public void onFailure(Exception e) { e.printStackTrace(); fail(e.getMessage()); } - }; + } // arrange { @@ -126,7 +127,7 @@ public void onFailure(Exception e) { PhysicalPlan plan = planSerializer.convertToPlan(firstResponder.getCursor().toString()); var secondResponder = new TestResponder(); - queryService.executePlan(plan, secondResponder); + queryService.execute(new FetchCursor(firstResponder.getCursor().toString()), secondResponder); // act 3: confirm that there's no cursor. } diff --git a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java index a86f2513771..e38f408514e 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java @@ -104,11 +104,9 @@ public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { } @Provides - public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine, - PlanSerializer planSerializer, - QueryService qs) { + public QueryPlanFactory queryPlanFactory(QueryService qs) { - return new QueryPlanFactory(qs, planSerializer); + return new QueryPlanFactory(qs); } @Provides diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json index 2d7f5f8c087..568b397f07b 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json @@ -8,7 +8,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json index 45988e35c7f..8d45714283d 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json @@ -31,7 +31,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" }, "children": [] } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index cbbc8c7b9cb..c48b18a609d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -102,7 +102,9 @@ public RestChannelConsumer prepareRequest( channel, createExplainResponseListener(channel, executionErrorHandler), fallbackHandler)); - } else { + } + // If close request, sqlService.closeCursor + else { return channel -> sqlService.execute( request, diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index bfc29b02d21..f63eb9e2042 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; +import java.util.Map; import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; @@ -70,7 +71,7 @@ public void explain(PhysicalPlan plan, ResponseListener listene @Override public ExplainResponseNode visitTableScan(TableScanOperator node, Object context) { return explain(node, context, explainNode -> { - explainNode.setDescription(ImmutableMap.of("request", node.explain())); + explainNode.setDescription(Map.of("request", node.explain())); }); } }; @@ -81,5 +82,4 @@ public ExplainResponseNode visitTableScan(TableScanOperator node, Object context } }); } - } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java deleted file mode 100644 index 4789a50896a..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import java.util.List; -import java.util.function.Consumer; -import java.util.function.Function; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.ToString; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; - -/** - * Scroll (cursor) request is used to page the search. This request is not configurable and has - * no search query. It just handles paging through responses to the initial request. - * It is used on second and next pagination (cursor) requests. - * First (initial) request is handled by {@link InitialPageRequestBuilder}. - */ -@EqualsAndHashCode -@RequiredArgsConstructor -public class ContinuePageRequest implements OpenSearchRequest { - private final String initialScrollId; - private final TimeValue scrollTimeout; - // ScrollId that OpenSearch returns after search. - private String responseScrollId; - - @EqualsAndHashCode.Exclude - @ToString.Exclude - @Getter - private final OpenSearchExprValueFactory exprValueFactory; - - @EqualsAndHashCode.Exclude - private boolean scrollFinished = false; - - @Override - public OpenSearchResponse search(Function searchAction, - Function scrollAction) { - SearchResponse openSearchResponse = scrollAction.apply(new SearchScrollRequest(initialScrollId) - .scroll(scrollTimeout)); - - // TODO if terminated_early - something went wrong, e.g. no scroll returned. - var response = new OpenSearchResponse(openSearchResponse, exprValueFactory, List.of()); - // on the last empty page, we should close the scroll - scrollFinished = response.isEmpty(); - responseScrollId = openSearchResponse.getScrollId(); - return response; - } - - @Override - public void clean(Consumer cleanAction) { - if (scrollFinished) { - cleanAction.accept(responseScrollId); - } - } - - @Override - public SearchSourceBuilder getSourceBuilder() { - throw new UnsupportedOperationException( - "SearchSourceBuilder is unavailable for ContinueScrollRequest"); - } - - @Override - public String toCursor() { - // on the last page, we shouldn't return the scroll to user, it is kept for closing (clean) - return scrollFinished ? null : responseScrollId; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java deleted file mode 100644 index b1a6589acab..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import lombok.Getter; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.sort.SortBuilder; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; - -/** - * Builds a {@link ContinuePageRequest} to handle subsequent pagination/scroll/cursor requests. - * Initial search requests is handled by {@link InitialPageRequestBuilder}. - */ -public class ContinuePageRequestBuilder extends PagedRequestBuilder { - - @Getter - private final OpenSearchRequest.IndexName indexName; - @Getter - private final String scrollId; - private final TimeValue scrollTimeout; - private final OpenSearchExprValueFactory exprValueFactory; - - /** Constructor. */ - public ContinuePageRequestBuilder(OpenSearchRequest.IndexName indexName, - String scrollId, - Settings settings, - OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.scrollId = scrollId; - this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - this.exprValueFactory = exprValueFactory; - } - - @Override - public OpenSearchRequest build() { - return new ContinuePageRequest(scrollId, scrollTimeout, exprValueFactory); - } - - @Override - public void pushDownFilter(QueryBuilder query) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownAggregation(Pair, - OpenSearchAggregationResponseParser> aggregationBuilder) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownSort(List> sortBuilders) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownLimit(Integer limit, Integer offset) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownHighlight(String field, Map arguments) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownProjects(Set projects) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushTypeMapping(Map typeMapping) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownNested(List> nestedArgs) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } - - @Override - public void pushDownTrackedScore(boolean trackScores) { - throw new UnsupportedOperationException("Cursor requests don't support any push down"); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java deleted file mode 100644 index 25b7253ecaa..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import lombok.Getter; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortBuilder; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; - -/** - * This builder assists creating the initial OpenSearch paging (scrolling) request. - * It is used only on the first page (pagination request). - * Subsequent requests (cursor requests) use {@link ContinuePageRequestBuilder}. - */ -public class InitialPageRequestBuilder extends PagedRequestBuilder { - - @Getter - private final OpenSearchRequest.IndexName indexName; - private final SearchSourceBuilder sourceBuilder; - private final OpenSearchExprValueFactory exprValueFactory; - private final TimeValue scrollTimeout; - - /** - * Constructor. - * @param indexName index being scanned - * @param pageSize page size - * @param exprValueFactory value factory - */ - // TODO accept indexName as string (same way as `OpenSearchRequestBuilder` does)? - public InitialPageRequestBuilder(OpenSearchRequest.IndexName indexName, - int pageSize, - Settings settings, - OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.exprValueFactory = exprValueFactory; - this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - this.sourceBuilder = new SearchSourceBuilder() - .from(0) - .size(pageSize) - .timeout(DEFAULT_QUERY_TIMEOUT); - } - - @Override - public OpenSearchScrollRequest build() { - return new OpenSearchScrollRequest(indexName, scrollTimeout, sourceBuilder, exprValueFactory); - } - - @Override - public void pushDownFilter(QueryBuilder query) { - throw new UnsupportedOperationException("Pagination does not support filter (WHERE clause)"); - } - - @Override - public void pushDownAggregation(Pair, - OpenSearchAggregationResponseParser> aggregationBuilder) { - throw new UnsupportedOperationException("Pagination does not support aggregations"); - } - - @Override - public void pushDownSort(List> sortBuilders) { - throw new UnsupportedOperationException("Pagination does not support sort (ORDER BY clause)"); - } - - @Override - public void pushDownLimit(Integer limit, Integer offset) { - throw new UnsupportedOperationException("Pagination does not support limit (LIMIT clause)"); - } - - @Override - public void pushDownHighlight(String field, Map arguments) { - throw new UnsupportedOperationException("Pagination does not support highlight function"); - } - - /** - * Push down project expression to OpenSearch. - */ - @Override - public void pushDownProjects(Set projects) { - sourceBuilder.fetchSource(projects.stream().map(ReferenceExpression::getAttr) - .distinct().toArray(String[]::new), new String[0]); - } - - @Override - public void pushTypeMapping(Map typeMapping) { - exprValueFactory.extendTypeMapping(typeMapping); - } - - @Override - public void pushDownNested(List> nestedArgs) { - throw new UnsupportedOperationException("Pagination does not support nested function"); - } - - @Override - public void pushDownTrackedScore(boolean trackScores) { - throw new UnsupportedOperationException("Pagination does not support score function"); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 63aeed02f03..45954a38718 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -6,9 +6,7 @@ package org.opensearch.sql.opensearch.request; -import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; - -import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Consumer; @@ -19,6 +17,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -47,6 +46,7 @@ public class OpenSearchQueryRequest implements OpenSearchRequest { private final SearchSourceBuilder sourceBuilder; + /** * OpenSearchExprValueFactory. */ @@ -102,7 +102,9 @@ public OpenSearchResponse search(Function searchA } else { searchDone = true; return new OpenSearchResponse( - searchAction.apply(searchRequest()), exprValueFactory, includes); + searchAction.apply(new SearchRequest() + .indices(indexName.getIndexNames()) + .source(sourceBuilder)), exprValueFactory, includes); } } @@ -111,15 +113,14 @@ public void clean(Consumer cleanAction) { //do nothing. } - /** - * Generate OpenSearch search request. - * - * @return search request - */ - @VisibleForTesting - protected SearchRequest searchRequest() { - return new SearchRequest() - .indices(indexName.getIndexNames()) - .source(sourceBuilder); + @Override + public boolean hasAnotherBatch() { + return false; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("OpenSearchQueryRequest serialization " + + "is not implemented."); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java index c5b6d60af36..e6fe9f32e47 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java @@ -6,20 +6,29 @@ package org.opensearch.sql.opensearch.request; +import java.io.IOException; import java.util.function.Consumer; import java.util.function.Function; import lombok.EqualsAndHashCode; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; /** * OpenSearch search request. */ -public interface OpenSearchRequest { +public interface OpenSearchRequest extends Writeable { + /** + * Default query timeout in minutes. + */ + TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + /** * Apply the search action or scroll action on request based on context. * @@ -37,33 +46,28 @@ OpenSearchResponse search(Function searchAction, */ void clean(Consumer cleanAction); - /** - * Get the SearchSourceBuilder. - * - * @return SearchSourceBuilder. - */ - SearchSourceBuilder getSourceBuilder(); - /** * Get the ElasticsearchExprValueFactory. * @return ElasticsearchExprValueFactory. */ OpenSearchExprValueFactory getExprValueFactory(); - default String toCursor() { - return ""; - } + boolean hasAnotherBatch(); /** * OpenSearch Index Name. * Indices are separated by ",". */ @EqualsAndHashCode - class IndexName { + class IndexName implements Writeable { private static final String COMMA = ","; private final String[] indexNames; + public IndexName(StreamInput si) throws IOException { + indexNames = si.readStringArray(); + } + public IndexName(String indexName) { this.indexNames = indexName.split(COMMA); } @@ -76,5 +80,10 @@ public String[] getIndexNames() { public String toString() { return String.join(COMMA, indexNames); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringArray(indexNames); + } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index f8d62ad7ce6..97512bec496 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -8,13 +8,11 @@ import static java.util.stream.Collectors.mapping; import static java.util.stream.Collectors.toList; -import static org.opensearch.index.query.QueryBuilders.boolQuery; import static org.opensearch.index.query.QueryBuilders.matchAllQuery; import static org.opensearch.index.query.QueryBuilders.nestedQuery; import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -37,42 +35,36 @@ import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.planner.logical.LogicalNested; + /** * OpenSearch search request builder. */ -@EqualsAndHashCode(callSuper = false) +@EqualsAndHashCode @Getter @ToString -public class OpenSearchRequestBuilder implements PushDownRequestBuilder { - - /** - * Default query timeout in minutes. - */ - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); +public class OpenSearchRequestBuilder { /** - * {@link OpenSearchRequest.IndexName}. + * Search request source builder. */ - private final OpenSearchRequest.IndexName indexName; + private final SearchSourceBuilder sourceBuilder; /** - * Index max result window. + * Query size of the request -- how many rows will be returned. */ - private final Integer maxResultWindow; + private int requestedTotalSize; /** - * Search request source builder. + * Size of each page request to return. */ - private final SearchSourceBuilder sourceBuilder; + private Integer pageSize = null; /** * OpenSearchExprValueFactory. @@ -80,42 +72,19 @@ public class OpenSearchRequestBuilder implements PushDownRequestBuilder { @EqualsAndHashCode.Exclude @ToString.Exclude private final OpenSearchExprValueFactory exprValueFactory; - - /** - * Query size of the request -- how many rows will be returned. - */ - private int querySize; - - /** - * Scroll context life time. - */ - private final TimeValue scrollTimeout; - - public OpenSearchRequestBuilder(String indexName, - Integer maxResultWindow, - Settings settings, - OpenSearchExprValueFactory exprValueFactory) { - this(new OpenSearchRequest.IndexName(indexName), maxResultWindow, settings, - exprValueFactory); - } + private int startFrom = 0; /** * Constructor. */ - public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, - Integer maxResultWindow, - Settings settings, + public OpenSearchRequestBuilder(int requestedTotalSize, OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.maxResultWindow = maxResultWindow; - this.exprValueFactory = exprValueFactory; - this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - this.querySize = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); + this.requestedTotalSize = requestedTotalSize; this.sourceBuilder = new SearchSourceBuilder() - .from(0) - .size(querySize) - .timeout(DEFAULT_QUERY_TIMEOUT) + .from(startFrom) + .timeout(OpenSearchRequest.DEFAULT_QUERY_TIMEOUT) .trackScores(false); + this.exprValueFactory = exprValueFactory; } /** @@ -123,25 +92,39 @@ public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, * * @return query request or scroll request */ - public OpenSearchRequest build() { - Integer from = sourceBuilder.from(); - Integer size = sourceBuilder.size(); - - if (from + size > maxResultWindow) { - sourceBuilder.size(maxResultWindow - from); - return new OpenSearchScrollRequest( - indexName, scrollTimeout, sourceBuilder, exprValueFactory); + public OpenSearchRequest build(OpenSearchRequest.IndexName indexName, + int maxResultWindow, TimeValue scrollTimeout) { + int size = requestedTotalSize; + if (pageSize == null) { + if (startFrom + size > maxResultWindow) { + sourceBuilder.size(maxResultWindow - startFrom); + return new OpenSearchScrollRequest( + indexName, scrollTimeout, sourceBuilder, exprValueFactory); + } else { + sourceBuilder.from(startFrom); + sourceBuilder.size(requestedTotalSize); + return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); + } } else { - return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); + if (startFrom != 0) { + throw new UnsupportedOperationException("Non-zero offset is not supported with pagination"); + } + sourceBuilder.size(pageSize); + return new OpenSearchScrollRequest(indexName, scrollTimeout, + sourceBuilder, exprValueFactory); } } + + boolean isBoolFilterQuery(QueryBuilder current) { + return (current instanceof BoolQueryBuilder); + } + /** * Push down query to DSL request. * * @param query query request */ - @Override public void pushDownFilter(QueryBuilder query) { QueryBuilder current = sourceBuilder.query(); @@ -167,10 +150,9 @@ public void pushDownFilter(QueryBuilder query) { * * @param aggregationBuilder pair of aggregation query and aggregation parser. */ - @Override public void pushDownAggregation( Pair, OpenSearchAggregationResponseParser> aggregationBuilder) { - aggregationBuilder.getLeft().forEach(builder -> sourceBuilder.aggregation(builder)); + aggregationBuilder.getLeft().forEach(sourceBuilder::aggregation); sourceBuilder.size(0); exprValueFactory.setParser(aggregationBuilder.getRight()); } @@ -180,7 +162,6 @@ public void pushDownAggregation( * * @param sortBuilders sortBuilders. */ - @Override public void pushDownSort(List> sortBuilders) { // TODO: Sort by _doc is added when filter push down. Remove both logic once doctest fixed. if (isSortByDocOnly()) { @@ -193,24 +174,26 @@ public void pushDownSort(List> sortBuilders) { } /** - * Push down size (limit) and from (offset) to DSL request. + * Pushdown size (limit) and from (offset) to DSL request. */ - @Override public void pushDownLimit(Integer limit, Integer offset) { - querySize = limit; + requestedTotalSize = limit; + startFrom = offset; sourceBuilder.from(offset).size(limit); } - @Override public void pushDownTrackedScore(boolean trackScores) { sourceBuilder.trackScores(trackScores); } + public void pushDownPageSize(int pageSize) { + this.pageSize = pageSize; + } + /** * Add highlight to DSL requests. * @param field name of the field to highlight */ - @Override public void pushDownHighlight(String field, Map arguments) { String unquotedField = StringUtils.unquoteText(field); if (sourceBuilder.highlighter() != null) { @@ -243,14 +226,12 @@ public void pushDownHighlight(String field, Map arguments) { /** * Push down project list to DSL requests. */ - @Override public void pushDownProjects(Set projects) { - final Set projectsSet = - projects.stream().map(ReferenceExpression::getAttr).collect(Collectors.toSet()); - sourceBuilder.fetchSource(projectsSet.toArray(new String[0]), new String[0]); + sourceBuilder.fetchSource( + projects.stream().map(ReferenceExpression::getAttr).distinct().toArray(String[]::new), + new String[0]); } - @Override public void pushTypeMapping(Map typeMapping) { exprValueFactory.extendTypeMapping(typeMapping); } @@ -258,7 +239,7 @@ public void pushTypeMapping(Map typeMapping) { private boolean isSortByDocOnly() { List> sorts = sourceBuilder.sorts(); if (sorts != null) { - return sorts.equals(Arrays.asList(SortBuilders.fieldSort(DOC_FIELD_NAME))); + return sorts.equals(List.of(SortBuilders.fieldSort(DOC_FIELD_NAME))); } return false; } @@ -267,7 +248,6 @@ private boolean isSortByDocOnly() { * Push down nested to sourceBuilder. * @param nestedArgs : Nested arguments to push down. */ - @Override public void pushDownNested(List> nestedArgs) { initBoolQueryFilter(); groupFieldNamesByPath(nestedArgs).forEach( @@ -277,6 +257,10 @@ fieldNames, createEmptyNestedQuery(path) ); } + public int getMaxResponseSize() { + return pageSize == null ? requestedTotalSize : pageSize; + } + /** * Initialize bool query for push down. */ diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java index 77c6a781fe9..7173eff171a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.request; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -18,11 +19,14 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; /** * OpenSearch scroll search request. This has to be stateful because it needs to: @@ -34,7 +38,7 @@ @Getter @ToString public class OpenSearchScrollRequest implements OpenSearchRequest { - + private final SearchRequest initialSearchRequest; /** Scroll context timeout. */ private final TimeValue scrollTimeout; @@ -47,19 +51,20 @@ public class OpenSearchScrollRequest implements OpenSearchRequest { @EqualsAndHashCode.Exclude @ToString.Exclude private final OpenSearchExprValueFactory exprValueFactory; - /** * Scroll id which is set after first request issued. Because ElasticsearchClient is shared by * multi-thread so this state has to be maintained here. */ @Setter @Getter - private String scrollId; + private String scrollId = NO_SCROLL_ID; + + public static final String NO_SCROLL_ID = ""; private boolean needClean = false; - /** Search request source builder. */ - private final SearchSourceBuilder sourceBuilder; + @Getter + private final List includes; /** Constructor. */ public OpenSearchScrollRequest(IndexName indexName, @@ -68,11 +73,20 @@ public OpenSearchScrollRequest(IndexName indexName, OpenSearchExprValueFactory exprValueFactory) { this.indexName = indexName; this.scrollTimeout = scrollTimeout; - this.sourceBuilder = sourceBuilder; this.exprValueFactory = exprValueFactory; + this.initialSearchRequest = new SearchRequest() + .indices(indexName.getIndexNames()) + .scroll(scrollTimeout) + .source(sourceBuilder); + + includes = sourceBuilder.fetchSource() == null + ? List.of() + : Arrays.asList(sourceBuilder.fetchSource().includes()); } - /** Constructor. */ + + /** Executes request using either {@param searchAction} or {@param scrollAction} as appropriate. + */ @Override public OpenSearchResponse search(Function searchAction, Function scrollAction) { @@ -80,15 +94,12 @@ public OpenSearchResponse search(Function searchA if (isScroll()) { openSearchResponse = scrollAction.apply(scrollRequest()); } else { - openSearchResponse = searchAction.apply(searchRequest()); + openSearchResponse = searchAction.apply(initialSearchRequest); } - FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); - List includes = fetchSource != null && fetchSource.includes() != null - ? Arrays.asList(this.sourceBuilder.fetchSource().includes()) - : List.of(); var response = new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); - if (!(needClean = response.isEmpty())) { + needClean = response.isEmpty(); + if (!needClean) { setScrollId(openSearchResponse.getScrollId()); } return response; @@ -100,32 +111,20 @@ public void clean(Consumer cleanAction) { // clean on the last page only, to prevent closing the scroll/cursor in the middle of paging. if (needClean && isScroll()) { cleanAction.accept(getScrollId()); - setScrollId(null); + setScrollId(NO_SCROLL_ID); } } finally { reset(); } } - /** - * Generate OpenSearch search request. - * - * @return search request - */ - public SearchRequest searchRequest() { - return new SearchRequest() - .indices(indexName.getIndexNames()) - .scroll(scrollTimeout) - .source(sourceBuilder); - } - /** * Is scroll started which means pages after first is being requested. * * @return true if scroll started */ public boolean isScroll() { - return scrollId != null; + return !scrollId.equals(NO_SCROLL_ID); } /** @@ -143,7 +142,7 @@ public SearchScrollRequest scrollRequest() { * to be reused across different physical plan. */ public void reset() { - scrollId = null; + scrollId = NO_SCROLL_ID; } /** @@ -151,7 +150,42 @@ public void reset() { * @return a string representing the scroll request. */ @Override - public String toCursor() { - return scrollId; + public boolean hasAnotherBatch() { + return !needClean && !scrollId.equals(NO_SCROLL_ID); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + initialSearchRequest.writeTo(out); + out.writeTimeValue(scrollTimeout); + out.writeBoolean(needClean); + if (!needClean) { + // If needClean is true, there is no more data to get from OpenSearch and scrollId is + // used only to clean up OpenSearch context. + + out.writeString(scrollId); + } + out.writeStringCollection(includes); + indexName.writeTo(out); + } + + /** + * Constructs OpenSearchScrollRequest from serialized representation. + * @param in stream to read data from. + * @param engine OpenSearchSqlEngine to get node-specific context. + * @throws IOException thrown if reading from input {@param in} fails. + */ + public OpenSearchScrollRequest(StreamInput in, OpenSearchStorageEngine engine) + throws IOException { + initialSearchRequest = new SearchRequest(in); + scrollTimeout = in.readTimeValue(); + needClean = in.readBoolean(); + if (!needClean) { + scrollId = in.readString(); + } + includes = in.readStringList(); + indexName = new IndexName(in); + OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString()); + exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes()); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java deleted file mode 100644 index 69309bd7c9b..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java +++ /dev/null @@ -1,12 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -public abstract class PagedRequestBuilder implements PushDownRequestBuilder { - public abstract OpenSearchRequest build(); - - public abstract OpenSearchRequest.IndexName getIndexName(); -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java deleted file mode 100644 index 59aa1949b6a..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import lombok.Getter; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.sort.SortBuilder; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; - -public interface PushDownRequestBuilder { - - default boolean isBoolFilterQuery(QueryBuilder current) { - return (current instanceof BoolQueryBuilder); - } - - void pushDownFilter(QueryBuilder query); - - void pushDownAggregation(Pair, - OpenSearchAggregationResponseParser> aggregationBuilder); - - void pushDownSort(List> sortBuilders); - - void pushDownLimit(Integer limit, Integer offset); - - void pushDownHighlight(String field, Map arguments); - - void pushDownProjects(Set projects); - - void pushTypeMapping(Map typeMapping); - - void pushDownNested(List> nestedArgs); - - void pushDownTrackedScore(boolean trackScores); -} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 949b1e53ecb..532d62333da 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -11,6 +11,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.RequiredArgsConstructor; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -20,14 +21,11 @@ import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; -import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; -import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScan; -import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScanBuilder; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalML; @@ -35,6 +33,7 @@ import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; /** OpenSearch table (index) implementation. */ @@ -169,27 +168,29 @@ public PhysicalPlan implement(LogicalPlan plan) { } @Override - public LogicalPlan optimize(LogicalPlan plan) { - // No-op because optimization already done in Planner - return plan; + public TableScanBuilder createScanBuilder() { + final int querySizeLimit = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); + + var builder = new OpenSearchRequestBuilder( + querySizeLimit, + createExprValueFactory()); + + return new OpenSearchIndexScanBuilder(builder) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder requestBuilder) { + final TimeValue cursorKeepAlive = + settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + return new OpenSearchIndexScan(client, requestBuilder.getMaxResponseSize(), + requestBuilder.build(indexName, getMaxResultWindow(), cursorKeepAlive)); + } + }; } - @Override - public TableScanBuilder createScanBuilder() { + private OpenSearchExprValueFactory createExprValueFactory() { Map allFields = new HashMap<>(); getReservedFieldTypes().forEach((k, v) -> allFields.put(k, OpenSearchDataType.of(v))); allFields.putAll(getFieldOpenSearchTypes()); - OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, - getMaxResultWindow(), new OpenSearchExprValueFactory(allFields)); - return new OpenSearchIndexScanBuilder(indexScan); - } - - @Override - public TableScanBuilder createPagedScanBuilder(int pageSize) { - var requestBuilder = new InitialPageRequestBuilder(indexName, pageSize, settings, - new OpenSearchExprValueFactory(getFieldOpenSearchTypes())); - var indexScan = new OpenSearchPagedIndexScan(client, requestBuilder); - return new OpenSearchPagedIndexScanBuilder(indexScan); + return new OpenSearchExprValueFactory(allFields); } @VisibleForTesting diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java index 2171fb564f3..3633e45449a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java @@ -6,18 +6,25 @@ package org.opensearch.sql.opensearch.storage.scan; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.Collections; import java.util.Iterator; import lombok.EqualsAndHashCode; -import lombok.Getter; import lombok.ToString; -import org.opensearch.sql.common.setting.Settings; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchRequest; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.storage.TableScanOperator; /** @@ -25,26 +32,20 @@ */ @EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) @ToString(onlyExplicitlyIncluded = true) -public class OpenSearchIndexScan extends TableScanOperator { +public class OpenSearchIndexScan extends TableScanOperator implements SerializablePlan { /** OpenSearch client. */ - private final OpenSearchClient client; - - /** Search request builder. */ - @EqualsAndHashCode.Include - @Getter - @ToString.Include - private final OpenSearchRequestBuilder requestBuilder; + private OpenSearchClient client; /** Search request. */ @EqualsAndHashCode.Include @ToString.Include private OpenSearchRequest request; - /** Total query size. */ + /** Largest number of rows allowed in the response. */ @EqualsAndHashCode.Include @ToString.Include - private Integer querySize; + private int maxResponseSize; /** Number of rows returned. */ private Integer queryCount; @@ -53,36 +54,19 @@ public class OpenSearchIndexScan extends TableScanOperator { private Iterator iterator; /** - * Constructor. - */ - public OpenSearchIndexScan(OpenSearchClient client, Settings settings, - String indexName, Integer maxResultWindow, - OpenSearchExprValueFactory exprValueFactory) { - this( - client, - settings, - new OpenSearchRequest.IndexName(indexName), - maxResultWindow, - exprValueFactory - ); - } - - /** - * Constructor. + * Creates index scan based on a provided OpenSearchRequestBuilder. */ - public OpenSearchIndexScan(OpenSearchClient client, Settings settings, - OpenSearchRequest.IndexName indexName, Integer maxResultWindow, - OpenSearchExprValueFactory exprValueFactory) { + public OpenSearchIndexScan(OpenSearchClient client, + int maxResponseSize, + OpenSearchRequest request) { this.client = client; - this.requestBuilder = new OpenSearchRequestBuilder( - indexName, maxResultWindow, settings, exprValueFactory); + this.maxResponseSize = maxResponseSize; + this.request = request; } @Override public void open() { super.open(); - querySize = requestBuilder.getQuerySize(); - request = requestBuilder.build(); iterator = Collections.emptyIterator(); queryCount = 0; fetchNextBatch(); @@ -90,7 +74,7 @@ public void open() { @Override public boolean hasNext() { - if (queryCount >= querySize) { + if (queryCount >= maxResponseSize) { iterator = Collections.emptyIterator(); } else if (!iterator.hasNext()) { fetchNextBatch(); @@ -126,6 +110,51 @@ public void close() { @Override public String explain() { - return getRequestBuilder().build().toString(); + return request.toString(); + } + + /** No-args constructor. + * @deprecated Exists only to satisfy Java serialization API. + */ + @Deprecated(since = "introduction") + public OpenSearchIndexScan() { + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + int reqSize = in.readInt(); + byte[] requestStream = new byte[reqSize]; + in.read(requestStream); + + var engine = (OpenSearchStorageEngine) ((PlanSerializer.CursorDeserializationStream) in) + .resolveObject("engine"); + + try (BytesStreamInput bsi = new BytesStreamInput(requestStream)) { + request = new OpenSearchScrollRequest(bsi, engine); + } + maxResponseSize = in.readInt(); + + client = engine.getClient(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (!request.hasAnotherBatch()) { + throw new NoCursorException(); + } + // request is not directly Serializable so.. + // 1. Serialize request to an opensearch byte stream. + BytesStreamOutput reqOut = new BytesStreamOutput(); + request.writeTo(reqOut); + reqOut.flush(); + + // 2. Extract byte[] from the opensearch byte stream + var reqAsBytes = reqOut.bytes().toBytesRef().bytes; + + // 3. Write out the byte[] to object output stream. + out.writeInt(reqAsBytes.length); + out.write(reqAsBytes); + + out.writeInt(maxResponseSize); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java index 74be670dcc1..84883b52092 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.ast.tree.Sort; @@ -15,57 +16,60 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.storage.TableScanOperator; -import org.opensearch.sql.storage.read.TableScanBuilder; /** * Index scan builder for aggregate query used by {@link OpenSearchIndexScanBuilder} internally. */ -class OpenSearchIndexScanAggregationBuilder extends TableScanBuilder { +@EqualsAndHashCode +class OpenSearchIndexScanAggregationBuilder implements PushDownQueryBuilder { /** OpenSearch index scan to be optimized. */ - private final OpenSearchIndexScan indexScan; + private final OpenSearchRequestBuilder requestBuilder; /** Aggregators pushed down. */ - private List aggregatorList; + private final List aggregatorList; /** Grouping items pushed down. */ - private List groupByList; + private final List groupByList; /** Sorting items pushed down. */ private List> sortList; - /** - * Initialize with given index scan and perform push-down optimization later. - * - * @param indexScan index scan not fully optimized yet - */ - OpenSearchIndexScanAggregationBuilder(OpenSearchIndexScan indexScan) { - this.indexScan = indexScan; + + OpenSearchIndexScanAggregationBuilder(OpenSearchRequestBuilder requestBuilder, + LogicalAggregation aggregation) { + this.requestBuilder = requestBuilder; + aggregatorList = aggregation.getAggregatorList(); + groupByList = aggregation.getGroupByList(); } @Override - public TableScanOperator build() { + public OpenSearchRequestBuilder build() { AggregationQueryBuilder builder = new AggregationQueryBuilder(new DefaultExpressionSerializer()); Pair, OpenSearchAggregationResponseParser> aggregationBuilder = builder.buildAggregationBuilder(aggregatorList, groupByList, sortList); - indexScan.getRequestBuilder().pushDownAggregation(aggregationBuilder); - indexScan.getRequestBuilder().pushTypeMapping( + requestBuilder.pushDownAggregation(aggregationBuilder); + requestBuilder.pushTypeMapping( builder.buildTypeMapping(aggregatorList, groupByList)); - return indexScan; + return requestBuilder; } @Override - public boolean pushDownAggregation(LogicalAggregation aggregation) { - aggregatorList = aggregation.getAggregatorList(); - groupByList = aggregation.getGroupByList(); - return true; + public boolean pushDownFilter(LogicalFilter filter) { + return false; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 024331d267b..c6df692095e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -5,14 +5,15 @@ package org.opensearch.sql.opensearch.storage.scan; -import com.google.common.annotations.VisibleForTesting; import lombok.EqualsAndHashCode; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; import org.opensearch.sql.storage.TableScanOperator; @@ -23,36 +24,39 @@ * by delegated builder internally. This is to avoid conditional check of different push down logic * for non-aggregate and aggregate query everywhere. */ -public class OpenSearchIndexScanBuilder extends TableScanBuilder { +public abstract class OpenSearchIndexScanBuilder extends TableScanBuilder { /** * Delegated index scan builder for non-aggregate or aggregate query. */ @EqualsAndHashCode.Include - private TableScanBuilder delegate; + private PushDownQueryBuilder delegate; /** Is limit operator pushed down. */ private boolean isLimitPushedDown = false; - @VisibleForTesting - OpenSearchIndexScanBuilder(TableScanBuilder delegate) { - this.delegate = delegate; + /** + * Constructor used during query execution. + */ + protected OpenSearchIndexScanBuilder(OpenSearchRequestBuilder requestBuilder) { + this.delegate = new OpenSearchIndexScanQueryBuilder(requestBuilder); + } /** - * Initialize with given index scan. - * - * @param indexScan index scan to optimize + * Constructor used for unit tests. */ - public OpenSearchIndexScanBuilder(OpenSearchIndexScan indexScan) { - this.delegate = new OpenSearchIndexScanQueryBuilder(indexScan); + protected OpenSearchIndexScanBuilder(PushDownQueryBuilder translator) { + this.delegate = translator; } @Override public TableScanOperator build() { - return delegate.build(); + return createScan(delegate.build()); } + protected abstract TableScanOperator createScan(OpenSearchRequestBuilder requestBuilder); + @Override public boolean pushDownFilter(LogicalFilter filter) { return delegate.pushDownFilter(filter); @@ -66,10 +70,13 @@ public boolean pushDownAggregation(LogicalAggregation aggregation) { // Switch to builder for aggregate query which has different push down logic // for later filter, sort and limit operator. - delegate = new OpenSearchIndexScanAggregationBuilder( - (OpenSearchIndexScan) delegate.build()); + delegate = new OpenSearchIndexScanAggregationBuilder(delegate.build(), aggregation); + return true; + } - return delegate.pushDownAggregation(aggregation); + @Override + public boolean pushDownPageSize(LogicalPaginate paginate) { + return delegate.pushDownPageSize(paginate); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java index d9b4e6b4e02..590272a9f1f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; @@ -29,34 +30,22 @@ import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.storage.TableScanOperator; -import org.opensearch.sql.storage.read.TableScanBuilder; /** * Index scan builder for simple non-aggregate query used by * {@link OpenSearchIndexScanBuilder} internally. */ @VisibleForTesting -class OpenSearchIndexScanQueryBuilder extends TableScanBuilder { +@EqualsAndHashCode +class OpenSearchIndexScanQueryBuilder implements PushDownQueryBuilder { - /** OpenSearch index scan to be optimized. */ - @EqualsAndHashCode.Include - private final OpenSearchIndexScan indexScan; + OpenSearchRequestBuilder requestBuilder; - /** - * Initialize with given index scan and perform push-down optimization later. - * - * @param indexScan index scan not optimized yet - */ - OpenSearchIndexScanQueryBuilder(OpenSearchIndexScan indexScan) { - this.indexScan = indexScan; - } - - @Override - public TableScanOperator build() { - return indexScan; + public OpenSearchIndexScanQueryBuilder(OpenSearchRequestBuilder requestBuilder) { + this.requestBuilder = requestBuilder; } @Override @@ -65,8 +54,8 @@ public boolean pushDownFilter(LogicalFilter filter) { new DefaultExpressionSerializer()); Expression queryCondition = filter.getCondition(); QueryBuilder query = queryBuilder.build(queryCondition); - indexScan.getRequestBuilder().pushDownFilter(query); - indexScan.getRequestBuilder().pushDownTrackedScore( + requestBuilder.pushDownFilter(query); + requestBuilder.pushDownTrackedScore( trackScoresFromOpenSearchFunction(queryCondition)); return true; } @@ -75,7 +64,7 @@ public boolean pushDownFilter(LogicalFilter filter) { public boolean pushDownSort(LogicalSort sort) { List> sortList = sort.getSortList(); final SortQueryBuilder builder = new SortQueryBuilder(); - indexScan.getRequestBuilder().pushDownSort(sortList.stream() + requestBuilder.pushDownSort(sortList.stream() .map(sortItem -> builder.build(sortItem.getValue(), sortItem.getKey())) .collect(Collectors.toList())); return true; @@ -83,13 +72,13 @@ public boolean pushDownSort(LogicalSort sort) { @Override public boolean pushDownLimit(LogicalLimit limit) { - indexScan.getRequestBuilder().pushDownLimit(limit.getLimit(), limit.getOffset()); + requestBuilder.pushDownLimit(limit.getLimit(), limit.getOffset()); return true; } @Override public boolean pushDownProject(LogicalProject project) { - indexScan.getRequestBuilder().pushDownProjects( + requestBuilder.pushDownProjects( findReferenceExpressions(project.getProjectList())); // Return false intentionally to keep the original project operator @@ -98,12 +87,18 @@ public boolean pushDownProject(LogicalProject project) { @Override public boolean pushDownHighlight(LogicalHighlight highlight) { - indexScan.getRequestBuilder().pushDownHighlight( + requestBuilder.pushDownHighlight( StringUtils.unquoteText(highlight.getHighlightField().toString()), highlight.getArguments()); return true; } + @Override + public boolean pushDownPageSize(LogicalPaginate paginate) { + requestBuilder.pushDownPageSize(paginate.getPageSize()); + return true; + } + private boolean trackScoresFromOpenSearchFunction(Expression condition) { if (condition instanceof OpenSearchFunctions.OpenSearchFunction && ((OpenSearchFunctions.OpenSearchFunction) condition).isScoreTracked()) { @@ -118,8 +113,8 @@ private boolean trackScoresFromOpenSearchFunction(Expression condition) { @Override public boolean pushDownNested(LogicalNested nested) { - indexScan.getRequestBuilder().pushDownNested(nested.getFields()); - indexScan.getRequestBuilder().pushDownProjects( + requestBuilder.pushDownNested(nested.getFields()); + requestBuilder.pushDownProjects( findReferenceExpressions(nested.getProjectList())); // Return false intentionally to keep the original nested operator // Since we return false we need to pushDownProject here as it won't be @@ -128,11 +123,16 @@ public boolean pushDownNested(LogicalNested nested) { return false; } + @Override + public OpenSearchRequestBuilder build() { + return requestBuilder; + } + /** * Find reference expression from expression. * @param expressions a list of expression. * - * @return a list of ReferenceExpression + * @return a set of ReferenceExpression */ public static Set findReferenceExpressions( List expressions) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java deleted file mode 100644 index 3667a3ffdfc..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Collections; -import java.util.Iterator; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; -import org.apache.commons.lang3.NotImplementedException; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.NoCursorException; -import org.opensearch.sql.executor.pagination.PlanSerializer; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.request.ContinuePageRequestBuilder; -import org.opensearch.sql.opensearch.request.OpenSearchRequest; -import org.opensearch.sql.opensearch.request.PagedRequestBuilder; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; -import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; -import org.opensearch.sql.planner.SerializablePlan; -import org.opensearch.sql.storage.TableScanOperator; - -@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) -@ToString(onlyExplicitlyIncluded = true) -public class OpenSearchPagedIndexScan extends TableScanOperator implements SerializablePlan { - private OpenSearchClient client; - @Getter - private PagedRequestBuilder requestBuilder; - @EqualsAndHashCode.Include - @ToString.Include - private OpenSearchRequest request; - private Iterator iterator; - private long totalHits = 0; - - public OpenSearchPagedIndexScan(OpenSearchClient client, PagedRequestBuilder requestBuilder) { - this.client = client; - this.requestBuilder = requestBuilder; - } - - @Override - public String explain() { - throw new NotImplementedException("Implement OpenSearchPagedIndexScan.explain"); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - - @Override - public ExprValue next() { - return iterator.next(); - } - - @Override - public void open() { - super.open(); - request = requestBuilder.build(); - OpenSearchResponse response = client.search(request); - if (!response.isEmpty()) { - iterator = response.iterator(); - totalHits = response.getTotalHits(); - } else { - iterator = Collections.emptyIterator(); - } - } - - @Override - public void close() { - super.close(); - client.cleanup(request); - } - - @Override - public long getTotalHits() { - return totalHits; - } - - /** Don't use, it is for deserialization needs only. */ - @Deprecated - public OpenSearchPagedIndexScan() { - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - var engine = (OpenSearchStorageEngine) ((PlanSerializer.CursorDeserializationStream) in) - .resolveObject("engine"); - var indexName = (String) in.readUTF(); - var scrollId = (String) in.readUTF(); - client = engine.getClient(); - var index = new OpenSearchIndex(client, engine.getSettings(), indexName); - requestBuilder = new ContinuePageRequestBuilder( - new OpenSearchRequest.IndexName(indexName), - scrollId, engine.getSettings(), - new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes())); - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - if (request.toCursor() == null || request.toCursor().isEmpty()) { - throw new NoCursorException(); - } - - out.writeUTF(requestBuilder.getIndexName().toString()); - out.writeUTF(request.toCursor()); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java deleted file mode 100644 index 779df4ebec9..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import lombok.EqualsAndHashCode; -import org.opensearch.sql.storage.TableScanOperator; -import org.opensearch.sql.storage.read.TableScanBuilder; - -/** - * Builder for a paged OpenSearch request. - * Override pushDown* methods from TableScanBuilder as more features - * support pagination. - */ -public class OpenSearchPagedIndexScanBuilder extends TableScanBuilder { - @EqualsAndHashCode.Include - OpenSearchPagedIndexScan indexScan; - - public OpenSearchPagedIndexScanBuilder(OpenSearchPagedIndexScan indexScan) { - this.indexScan = indexScan; - } - - @Override - public TableScanOperator build() { - return indexScan; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java new file mode 100644 index 00000000000..274bc4647d9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +/** + * Translates a logical query plan into OpenSearch DSL and an appropriate request. + */ +public interface PushDownQueryBuilder { + default boolean pushDownFilter(LogicalFilter filter) { + return false; + } + + default boolean pushDownSort(LogicalSort sort) { + return false; + } + + default boolean pushDownLimit(LogicalLimit limit) { + return false; + } + + default boolean pushDownProject(LogicalProject project) { + return false; + } + + default boolean pushDownHighlight(LogicalHighlight highlight) { + return false; + } + + default boolean pushDownPageSize(LogicalPaginate paginate) { + return false; + } + + default boolean pushDownNested(LogicalNested nested) { + return false; + } + + OpenSearchRequestBuilder build(); +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index e3c9291ced4..b378fae2970 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -29,7 +29,6 @@ import com.google.common.io.Resources; import java.io.IOException; import java.net.URL; -import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -97,14 +96,11 @@ class OpenSearchNodeClientTest { @Mock private SearchHit searchHit; - @Mock - private ThreadContext threadContext; - @Mock private GetIndexResponse indexResponse; - private ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id", - new ExprIntegerValue(1))); + private final ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap( + Map.of("id", new ExprIntegerValue(1))); private OpenSearchClient client; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index 22d02d1ab5d..2958fa11006 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -78,7 +78,6 @@ class OpenSearchRestClientTest { private static final String TEST_MAPPING_FILE = "mappings/accounts.json"; - @Mock(answer = RETURNS_DEEP_STUBS) private RestHighLevelClient restClient; @@ -93,8 +92,8 @@ class OpenSearchRestClientTest { @Mock private GetIndexResponse getIndexResponse; - private ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id", - new ExprIntegerValue(1))); + private final ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap( + Map.of("id", new ExprIntegerValue(1))); @BeforeEach void setUp() { @@ -362,9 +361,7 @@ void scroll_with_IOException() throws IOException { void schedule() { AtomicBoolean isRun = new AtomicBoolean(false); client.schedule( - () -> { - isRun.set(true); - }); + () -> isRun.set(true)); assertTrue(isRun.get()); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java index c96782abea4..330793a5d65 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java @@ -17,12 +17,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.common.setting.Settings.Key.QUERY_SIZE_LIMIT; import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.executor.ExecutionEngine.QueryResponse; -import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.ArrayList; @@ -50,6 +48,8 @@ import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -96,17 +96,17 @@ void execute_successfully() { List actual = new ArrayList<>(); executor.execute( plan, - new ResponseListener() { - @Override - public void onResponse(QueryResponse response) { - actual.addAll(response.getResults()); - } - - @Override - public void onFailure(Exception e) { - fail("Error occurred during execution", e); - } - }); + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + actual.addAll(response.getResults()); + } + + @Override + public void onFailure(Exception e) { + fail("Error occurred during execution", e); + } + }); assertTrue(plan.hasOpen); assertEquals(expected, actual); @@ -126,18 +126,18 @@ void execute_with_cursor() { List actual = new ArrayList<>(); executor.execute( plan, - new ResponseListener() { - @Override - public void onResponse(QueryResponse response) { - actual.addAll(response.getResults()); - assertTrue(response.getCursor().toString().startsWith("n:")); - } - - @Override - public void onFailure(Exception e) { - fail("Error occurred during execution", e); - } - }); + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + actual.addAll(response.getResults()); + assertTrue(response.getCursor().toString().startsWith("n:")); + } + + @Override + public void onFailure(Exception e) { + fail("Error occurred during execution", e); + } + }); assertEquals(expected, actual); } @@ -154,17 +154,17 @@ void execute_with_failure() { AtomicReference actual = new AtomicReference<>(); executor.execute( plan, - new ResponseListener() { - @Override - public void onResponse(QueryResponse response) { - fail("Expected error didn't happen"); - } - - @Override - public void onFailure(Exception e) { - actual.set(e); - } - }); + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + fail("Expected error didn't happen"); + } + + @Override + public void onFailure(Exception e) { + actual.set(e); + } + }); assertEquals(expected, actual.get()); verify(plan).close(); } @@ -174,15 +174,20 @@ void explain_successfully() { OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, new PlanSerializer(null)); Settings settings = mock(Settings.class); - when(settings.getSettingValue(QUERY_SIZE_LIMIT)).thenReturn(100); when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)) .thenReturn(TimeValue.timeValueMinutes(1)); - PhysicalPlan plan = new OpenSearchIndexScan(mock(OpenSearchClient.class), settings, - "test", 10000, mock(OpenSearchExprValueFactory.class)); + OpenSearchExprValueFactory exprValueFactory = mock(OpenSearchExprValueFactory.class); + final var name = new OpenSearchRequest.IndexName("test"); + final int defaultQuerySize = 100; + final int maxResultWindow = 10000; + final var requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); + PhysicalPlan plan = new OpenSearchIndexScan(mock(OpenSearchClient.class), + maxResultWindow, requestBuilder.build(name, maxResultWindow, + settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE))); AtomicReference result = new AtomicReference<>(); - executor.explain(plan, new ResponseListener() { + executor.explain(plan, new ResponseListener<>() { @Override public void onResponse(ExplainResponse response) { result.set(response); @@ -205,7 +210,7 @@ void explain_with_failure() { when(plan.accept(any(), any())).thenThrow(IllegalStateException.class); AtomicReference result = new AtomicReference<>(); - executor.explain(plan, new ResponseListener() { + executor.explain(plan, new ResponseListener<>() { @Override public void onResponse(ExplainResponse response) { fail("Should fail as expected"); @@ -261,11 +266,11 @@ private static class FakePhysicalPlan extends TableScanOperator implements Seria private boolean hasSplit; @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public void readExternal(ObjectInput in) { } @Override - public void writeExternal(ObjectOutput out) throws IOException { + public void writeExternal(ObjectOutput out) { } @Override diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index fe0077914e2..1c978c849ed 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -9,10 +9,7 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; -import static org.opensearch.sql.common.setting.Settings.Key.QUERY_SIZE_LIMIT; -import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -26,7 +23,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,11 +35,11 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -60,6 +56,8 @@ import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.planner.physical.NestedOperator; @@ -89,22 +87,20 @@ public void setup() { } @Test - public void testProtectIndexScan() { - when(settings.getSettingValue(QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); + void testProtectIndexScan() { String indexName = "test"; - Integer maxResultWindow = 10000; + final int maxResultWindow = 10000; + final int querySizeLimit = 200; NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); ReferenceExpression dedupeField = ref("name", STRING); ReferenceExpression topField = ref("name", STRING); - List topExprs = Arrays.asList(ref("age", INTEGER)); + List topExprs = List.of(ref("age", INTEGER)); Expression filterExpr = literal(ExprBooleanValue.of(true)); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); + List groupByExprs = List.of(named("age", ref("age", INTEGER))); List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); + List.of(named("avg(age)", new AvgAggregator(List.of(ref("age", INTEGER)), + DOUBLE))); Map mappings = ImmutableMap.of(ref("name", STRING), ref("lastname", STRING)); Pair newEvalField = @@ -114,6 +110,10 @@ public void testProtectIndexScan() { Integer limit = 10; Integer offset = 10; + final var name = new OpenSearchRequest.IndexName(indexName); + final var request = new OpenSearchRequestBuilder(querySizeLimit, exprValueFactory) + .build(name, maxResultWindow, + settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.limit( @@ -127,10 +127,8 @@ public void testProtectIndexScan() { PhysicalPlanDSL.agg( filter( resourceMonitor( - new OpenSearchIndexScan(client, settings, - indexName, - maxResultWindow, - exprValueFactory)), + new OpenSearchIndexScan(client, + maxResultWindow, request)), filterExpr), aggregators, groupByExprs), @@ -156,10 +154,8 @@ public void testProtectIndexScan() { PhysicalPlanDSL.rename( PhysicalPlanDSL.agg( filter( - new OpenSearchIndexScan(client, settings, - indexName, - maxResultWindow, - exprValueFactory), + new OpenSearchIndexScan(client, + maxResultWindow, request), filterExpr), aggregators, groupByExprs), @@ -178,7 +174,7 @@ public void testProtectIndexScan() { @SuppressWarnings("unchecked") @Test - public void testProtectSortForWindowOperator() { + void testProtectSortForWindowOperator() { NamedExpression rank = named(mock(RankFunction.class)); Pair sortItem = ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)); @@ -204,7 +200,7 @@ public void testProtectSortForWindowOperator() { } @Test - public void testProtectWindowOperatorInput() { + void testProtectWindowOperatorInput() { NamedExpression avg = named(mock(AggregateWindowFunction.class)); WindowDefinition windowDefinition = mock(WindowDefinition.class); @@ -223,7 +219,7 @@ public void testProtectWindowOperatorInput() { @SuppressWarnings("unchecked") @Test - public void testNotProtectWindowOperatorInputIfAlreadyProtected() { + void testNotProtectWindowOperatorInputIfAlreadyProtected() { NamedExpression avg = named(mock(AggregateWindowFunction.class)); Pair sortItem = ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)); @@ -248,7 +244,7 @@ public void testNotProtectWindowOperatorInputIfAlreadyProtected() { } @Test - public void testWithoutProtection() { + void testWithoutProtection() { Expression filterExpr = literal(ExprBooleanValue.of(true)); assertEquals( @@ -264,7 +260,7 @@ public void testWithoutProtection() { } @Test - public void testVisitMlCommons() { + void testVisitMlCommons() { NodeClient nodeClient = mock(NodeClient.class); MLCommonsOperator mlCommonsOperator = new MLCommonsOperator( @@ -282,7 +278,7 @@ public void testVisitMlCommons() { } @Test - public void testVisitAD() { + void testVisitAD() { NodeClient nodeClient = mock(NodeClient.class); ADOperator adOperator = new ADOperator( @@ -300,7 +296,7 @@ public void testVisitAD() { } @Test - public void testVisitML() { + void testVisitML() { NodeClient nodeClient = mock(NodeClient.class); MLOperator mlOperator = new MLOperator( @@ -320,7 +316,7 @@ public void testVisitML() { } @Test - public void testVisitNested() { + void testVisitNested() { Set args = Set.of("message.info"); Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java deleted file mode 100644 index 5cabe1930d0..00000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; - -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -@ExtendWith(MockitoExtension.class) -public class ContinuePageRequestBuilderTest { - - @Mock - private OpenSearchExprValueFactory exprValueFactory; - - @Mock - private Settings settings; - - private final OpenSearchRequest.IndexName indexName = new OpenSearchRequest.IndexName("test"); - private final String scrollId = "scroll"; - - private ContinuePageRequestBuilder requestBuilder; - - @BeforeEach - void setup() { - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); - requestBuilder = new ContinuePageRequestBuilder( - indexName, scrollId, settings, exprValueFactory); - } - - @Test - public void build() { - assertEquals( - new ContinuePageRequest(scrollId, TimeValue.timeValueMinutes(1), exprValueFactory), - requestBuilder.build() - ); - } - - @Test - public void getIndexName() { - assertEquals(indexName, requestBuilder.getIndexName()); - } - - @Test - public void pushDown_not_supported() { - assertAll( - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownFilter(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownAggregation(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownSort(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownLimit(1, 2)), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownHighlight("", Map.of())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownProjects(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushTypeMapping(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownNested(List.of())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownTrackedScore(true)) - ); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java deleted file mode 100644 index e991fc5787d..00000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.function.Consumer; -import java.util.function.Function; -import lombok.SneakyThrows; -import org.apache.commons.lang3.reflect.FieldUtils; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; - -@ExtendWith(MockitoExtension.class) -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class ContinuePageRequestTest { - - @Mock - private Function searchAction; - - @Mock - private Function scrollAction; - - @Mock - private Consumer cleanAction; - - @Mock - private SearchResponse searchResponse; - - @Mock - private SearchHits searchHits; - - @Mock - private SearchHit searchHit; - - @Mock - private OpenSearchExprValueFactory factory; - - private final String scroll = "scroll"; - private final String nextScroll = "nextScroll"; - - private final ContinuePageRequest request = new ContinuePageRequest( - scroll, TimeValue.timeValueMinutes(1), factory); - - @Test - public void search_with_non_empty_response() { - when(scrollAction.apply(any())).thenReturn(searchResponse); - when(searchResponse.getHits()).thenReturn(searchHits); - when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); - when(searchResponse.getScrollId()).thenReturn(nextScroll); - - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - assertAll( - () -> assertFalse(searchResponse.isEmpty()), - () -> assertEquals(nextScroll, request.toCursor()), - () -> verify(scrollAction, times(1)).apply(any()), - () -> verify(searchAction, never()).apply(any()) - ); - } - - @Test - // Empty response means scroll search is done and no cursor/scroll should be set - public void search_with_empty_response() { - when(scrollAction.apply(any())).thenReturn(searchResponse); - when(searchResponse.getHits()).thenReturn(searchHits); - when(searchHits.getHits()).thenReturn(null); - lenient().when(searchResponse.getScrollId()).thenReturn(nextScroll); - - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - assertAll( - () -> assertTrue(searchResponse.isEmpty()), - () -> assertNull(request.toCursor()), - () -> verify(scrollAction, times(1)).apply(any()), - () -> verify(searchAction, never()).apply(any()) - ); - } - - @Test - @SneakyThrows - public void clean() { - request.clean(cleanAction); - verify(cleanAction, never()).accept(any()); - // Enforce cleaning by setting a private field. - FieldUtils.writeField(request, "scrollFinished", true, true); - request.clean(cleanAction); - verify(cleanAction, times(1)).accept(any()); - } - - @Test - // Added for coverage only - public void getters() { - factory = mock(); - assertAll( - () -> assertThrows(Throwable.class, request::getSourceBuilder), - () -> assertSame(factory, new ContinuePageRequest("", null, factory).getExprValueFactory()) - ); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java deleted file mode 100644 index ef850380d42..00000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.request; - -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; - -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -@ExtendWith(MockitoExtension.class) -public class InitialPageRequestBuilderTest { - - @Mock - private OpenSearchExprValueFactory exprValueFactory; - - @Mock - private Settings settings; - - private final int pageSize = 42; - - private final OpenSearchRequest.IndexName indexName = new OpenSearchRequest.IndexName("test"); - - private InitialPageRequestBuilder requestBuilder; - - @BeforeEach - void setup() { - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); - requestBuilder = new InitialPageRequestBuilder( - indexName, pageSize, settings, exprValueFactory); - } - - @Test - public void build() { - assertEquals( - new OpenSearchScrollRequest(indexName, TimeValue.timeValueMinutes(1), - new SearchSourceBuilder() - .from(0) - .size(pageSize) - .timeout(DEFAULT_QUERY_TIMEOUT), - exprValueFactory), - requestBuilder.build() - ); - } - - @Test - public void pushDown_not_supported() { - assertAll( - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownFilter(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownAggregation(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownSort(mock())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownLimit(1, 2)), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownHighlight("", Map.of())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownNested(List.of())), - () -> assertThrows(UnsupportedOperationException.class, - () -> requestBuilder.pushDownTrackedScore(true)) - ); - } - - @Test - public void pushTypeMapping() { - Map typeMapping = Map.of("intA", OpenSearchDataType.of(INTEGER)); - requestBuilder.pushTypeMapping(typeMapping); - - verify(exprValueFactory).extendTypeMapping(typeMapping); - } - - @Test - public void pushDownProject() { - Set references = Set.of(DSL.ref("intA", INTEGER)); - requestBuilder.pushDownProjects(references); - - assertEquals( - new OpenSearchScrollRequest(indexName, TimeValue.timeValueMinutes(1), - new SearchSourceBuilder() - .from(0) - .size(pageSize) - .timeout(DEFAULT_QUERY_TIMEOUT) - .fetchSource(new String[]{"intA"}, new String[0]), - exprValueFactory), - requestBuilder.build() - ); - } - - @Test - public void getIndexName() { - assertEquals(indexName, requestBuilder.getIndexName()); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index 0aa1f926edd..a92cb44d7a5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -8,17 +8,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; +import static org.opensearch.sql.opensearch.request.OpenSearchRequest.DEFAULT_QUERY_TIMEOUT; -import java.util.Iterator; import java.util.function.Consumer; import java.util.function.Function; +import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -26,12 +28,12 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; -import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -105,10 +107,10 @@ void search_withoutContext() { when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); verify(sourceBuilder, times(1)).fetchSource(); assertFalse(searchResponse.isEmpty()); + assertFalse(request.hasAnotherBatch()); } @Test @@ -146,22 +148,21 @@ void clean() { void searchRequest() { request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); - assertEquals( - new SearchRequest() - .indices("test") - .source(new SearchSourceBuilder() - .timeout(DEFAULT_QUERY_TIMEOUT) - .from(0) - .size(200) - .query(QueryBuilders.termQuery("name", "John"))), - request.searchRequest()); + assertSearchRequest(new SearchRequest() + .indices("test") + .source(new SearchSourceBuilder() + .timeout(DEFAULT_QUERY_TIMEOUT) + .from(0) + .size(200) + .query(QueryBuilders.termQuery("name", "John"))), + request); } @Test void searchCrossClusterRequest() { remoteRequest.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); - assertEquals( + assertSearchRequest( new SearchRequest() .indices("ccs:test") .source(new SearchSourceBuilder() @@ -169,6 +170,23 @@ void searchCrossClusterRequest() { .from(0) .size(200) .query(QueryBuilders.termQuery("name", "John"))), - remoteRequest.searchRequest()); + remoteRequest); + } + + @Test + void writeTo_unsupported() { + assertThrows(UnsupportedOperationException.class, + () -> request.writeTo(mock(StreamOutput.class))); + } + + private void assertSearchRequest(SearchRequest expected, OpenSearchQueryRequest request) { + Function querySearch = searchRequest -> { + assertEquals(expected, searchRequest); + return when(mock(SearchResponse.class).getHits()) + .thenReturn(new SearchHits(new SearchHit[0], + new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f)) + .getMock(); + }; + request.search(querySearch, searchScrollRequest -> null); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 94433c29b96..21618a436d2 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -6,7 +6,9 @@ package org.opensearch.sql.opensearch.request; +import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.index.query.QueryBuilders.matchAllQuery; @@ -20,7 +22,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; @@ -29,11 +33,16 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.InnerHitBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; @@ -42,7 +51,7 @@ import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.ScoreSortBuilder; import org.opensearch.search.sort.SortBuilders; -import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; @@ -55,15 +64,15 @@ @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class OpenSearchRequestBuilderTest { +class OpenSearchRequestBuilderTest { private static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); private static final Integer DEFAULT_OFFSET = 0; private static final Integer DEFAULT_LIMIT = 200; private static final Integer MAX_RESULT_WINDOW = 500; - @Mock - private Settings settings; + private static final OpenSearchRequest.IndexName indexName + = new OpenSearchRequest.IndexName("test"); @Mock private OpenSearchExprValueFactory exprValueFactory; @@ -72,12 +81,7 @@ public class OpenSearchRequestBuilderTest { @BeforeEach void setup() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); - - requestBuilder = new OpenSearchRequestBuilder( - "test", MAX_RESULT_WINDOW, settings, exprValueFactory); + requestBuilder = new OpenSearchRequestBuilder(DEFAULT_LIMIT, exprValueFactory); } @Test @@ -96,7 +100,7 @@ void build_query_request() { .timeout(DEFAULT_QUERY_TIMEOUT) .trackScores(true), exprValueFactory), - requestBuilder.build()); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); } @Test @@ -113,7 +117,7 @@ void build_scroll_request_with_correct_size() { .size(MAX_RESULT_WINDOW - offset) .timeout(DEFAULT_QUERY_TIMEOUT), exprValueFactory), - requestBuilder.build()); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); } @Test @@ -121,15 +125,24 @@ void test_push_down_query() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); requestBuilder.pushDownFilter(query); - assertEquals( + var r = requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT); + Function querySearch = searchRequest -> { + assertEquals( new SearchSourceBuilder() - .from(DEFAULT_OFFSET) - .size(DEFAULT_LIMIT) - .timeout(DEFAULT_QUERY_TIMEOUT) - .query(query) - .sort(DOC_FIELD_NAME, ASC), - requestBuilder.getSourceBuilder() - ); + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(DOC_FIELD_NAME, ASC), + searchRequest.source() + ); + return mock(); + }; + Function scrollSearch = searchScrollRequest -> { + throw new UnsupportedOperationException(); + }; + r.search(querySearch, scrollSearch); + } @Test @@ -161,14 +174,31 @@ void test_push_down_query_and_sort() { FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .query(query) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); + } + + void assertSearchSourceBuilder(SearchSourceBuilder expected, + OpenSearchRequestBuilder requestBuilder) + throws UnsupportedOperationException { + Function querySearch = searchRequest -> { + assertEquals(expected, searchRequest.source()); + return when(mock(SearchResponse.class).getHits()) + .thenReturn(new SearchHits(new SearchHit[0], new TotalHits(0, + TotalHits.Relation.EQUAL_TO), 0.0f)) + .getMock(); + }; + Function scrollSearch = searchScrollRequest -> { + throw new UnsupportedOperationException(); + }; + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT).search( + querySearch, scrollSearch); } @Test @@ -176,13 +206,13 @@ void test_push_down_sort() { FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -190,13 +220,13 @@ void test_push_down_non_field_sort() { ScoreSortBuilder sortBuilder = SortBuilders.scoreSort(); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -205,14 +235,14 @@ void test_push_down_multiple_sort() { SortBuilders.fieldSort("intA"), SortBuilders.fieldSort("intB"))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(SortBuilders.fieldSort("intA")) .sort(SortBuilders.fieldSort("intB")), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -220,13 +250,13 @@ void test_push_down_project() { Set references = Set.of(DSL.ref("intA", INTEGER)); requestBuilder.pushDownProjects(references); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .fetchSource(new String[]{"intA"}, new String[0]), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -250,13 +280,13 @@ void test_push_down_nested() { .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -283,13 +313,13 @@ void test_push_down_multiple_nested_with_same_path() { NestedQueryBuilder nestedQuery = nestedQuery("message", matchAllQuery(), ScoreMode.None) .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info", "message.from"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -314,7 +344,7 @@ void test_push_down_nested_with_filter() { .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .query( QueryBuilders.boolQuery().filter( @@ -326,7 +356,7 @@ void test_push_down_nested_with_filter() { .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -336,4 +366,43 @@ void test_push_type_mapping() { verify(exprValueFactory).extendTypeMapping(typeMapping); } + + @Test + void push_down_highlight_with_repeating_fields() { + requestBuilder.pushDownHighlight("name", Map.of()); + var exception = assertThrows(SemanticCheckException.class, () -> + requestBuilder.pushDownHighlight("name", Map.of())); + assertEquals("Duplicate field name in highlight", exception.getMessage()); + } + + @Test + void push_down_page_size() { + requestBuilder.pushDownPageSize(3); + assertSearchSourceBuilder( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(3) + .timeout(DEFAULT_QUERY_TIMEOUT), + requestBuilder); + } + + @Test + void exception_when_non_zero_offset_and_page_size() { + requestBuilder.pushDownPageSize(3); + requestBuilder.pushDownLimit(300, 2); + assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + } + + @Test + void maxResponseSize_is_page_size() { + requestBuilder.pushDownPageSize(4); + assertEquals(4, requestBuilder.getMaxResponseSize()); + } + + @Test + void maxResponseSize_is_limit() { + requestBuilder.pushDownLimit(100, 0); + assertEquals(100, requestBuilder.getMaxResponseSize()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java deleted file mode 100644 index d0a274ce2a2..00000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.request; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.CALLS_REAL_METHODS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.withSettings; - -import org.junit.jupiter.api.Test; - -public class OpenSearchRequestTest { - - @Test - void toCursor() { - var request = mock(OpenSearchRequest.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)); - assertEquals("", request.toCursor()); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java index 461184e6d51..a2585620aa7 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java @@ -8,17 +8,21 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.opensearch.request.OpenSearchScrollRequest.NO_SCROLL_ID; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import lombok.SneakyThrows; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; @@ -29,19 +33,25 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchScrollRequestTest { + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); + public static final TimeValue SCROLL_TIMEOUT = TimeValue.timeValueMinutes(1); @Mock private Function searchAction; @@ -60,25 +70,44 @@ class OpenSearchScrollRequestTest { @Mock private SearchSourceBuilder sourceBuilder; - @Mock - private FetchSourceContext fetchSourceContext; @Mock private OpenSearchExprValueFactory factory; + private final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); private final OpenSearchScrollRequest request = new OpenSearchScrollRequest( - new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), - new SearchSourceBuilder(), factory); + INDEX_NAME, SCROLL_TIMEOUT, + searchSourceBuilder, factory); @Test - void searchRequest() { - request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); + void constructor() { + searchSourceBuilder.fetchSource(new String[] {"test"}, null); + var request = new OpenSearchScrollRequest(INDEX_NAME, SCROLL_TIMEOUT, + searchSourceBuilder, factory); + assertNotEquals(List.of(), request.getIncludes()); + } - assertEquals( + @Test + void constructor2() { + searchSourceBuilder.fetchSource(new String[]{"test"}, null); + var request = new OpenSearchScrollRequest(INDEX_NAME, SCROLL_TIMEOUT, searchSourceBuilder, + factory); + assertNotEquals(List.of(), request.getIncludes()); + } + + @Test + void searchRequest() { + searchSourceBuilder.query(QueryBuilders.termQuery("name", "John")); + request.search(searchRequest -> { + assertEquals( new SearchRequest() - .indices("test") - .scroll(TimeValue.timeValueMinutes(1)) - .source(new SearchSourceBuilder().query(QueryBuilders.termQuery("name", "John"))), - request.searchRequest()); + .indices("test") + .scroll(TimeValue.timeValueMinutes(1)) + .source(new SearchSourceBuilder().query(QueryBuilders.termQuery("name", "John"))), + searchRequest); + SearchHits searchHitsMock = when(mock(SearchHits.class).getHits()) + .thenReturn(new SearchHit[0]).getMock(); + return when(mock(SearchResponse.class).getHits()).thenReturn(searchHitsMock).getMock(); + }, searchScrollRequest -> null); } @Test @@ -111,16 +140,16 @@ void search() { factory ); - String[] includes = {"_id", "_index"}; - when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); - when(fetchSourceContext.includes()).thenReturn(includes); - when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - verify(fetchSourceContext, times(2)).includes(); - assertFalse(searchResponse.isEmpty()); + Function scrollSearch = searchScrollRequest -> { + throw new AssertionError(); + }; + OpenSearchResponse openSearchResponse = request.search(searchRequest -> searchResponse, + scrollSearch); + + assertFalse(openSearchResponse.isEmpty()); } @Test @@ -132,7 +161,6 @@ void search_withoutContext() { factory ); - when(sourceBuilder.fetchSource()).thenReturn(null); when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); @@ -151,24 +179,24 @@ void search_withoutIncludes() { factory ); - when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); - when(fetchSourceContext.includes()).thenReturn(null); when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - verify(fetchSourceContext, times(1)).includes(); assertFalse(searchResponse.isEmpty()); } @Test - void toCursor() { + void hasAnotherBatch() { request.setScrollId("scroll123"); - assertEquals("scroll123", request.toCursor()); + assertTrue(request.hasAnotherBatch()); request.reset(); - assertNull(request.toCursor()); + assertFalse(request.hasAnotherBatch()); + + request.setScrollId(""); + assertFalse(request.hasAnotherBatch()); } @Test @@ -188,7 +216,7 @@ void clean_on_empty_response() { AtomicBoolean cleanCalled = new AtomicBoolean(false); request.clean((s) -> cleanCalled.set(true)); - assertNull(request.getScrollId()); + assertEquals(NO_SCROLL_ID, request.getScrollId()); assertTrue(cleanCalled.get()); } @@ -203,7 +231,17 @@ void no_clean_on_non_empty_response() { assertEquals("scroll", request.getScrollId()); request.clean((s) -> fail()); - assertNull(request.getScrollId()); + assertEquals(NO_SCROLL_ID, request.getScrollId()); + } + + @Test + void no_cursor_on_empty_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], null, 1f)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertFalse(request.hasAnotherBatch()); } @Test @@ -213,8 +251,73 @@ void no_clean_if_no_scroll_in_response() { new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); request.search((x) -> searchResponse, (x) -> searchResponse); - assertNull(request.getScrollId()); + assertEquals(NO_SCROLL_ID, request.getScrollId()); request.clean((s) -> fail()); } + + @Test + @SneakyThrows + void serialize_deserialize_no_needClean() { + var stream = new BytesStreamOutput(); + request.writeTo(stream); + stream.flush(); + assertTrue(stream.size() > 0); + + // deserialize + var inStream = new BytesStreamInput(stream.bytes().toBytesRef().bytes); + var indexMock = mock(OpenSearchIndex.class); + var engine = mock(OpenSearchStorageEngine.class); + when(engine.getTable(any(), any())).thenReturn(indexMock); + var newRequest = new OpenSearchScrollRequest(inStream, engine); + assertEquals(request.getInitialSearchRequest(), newRequest.getInitialSearchRequest()); + assertEquals("", newRequest.getScrollId()); + } + + @Test + @SneakyThrows + void serialize_deserialize_needClean() { + lenient().when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + lenient().when(searchResponse.getScrollId()).thenReturn(""); + + var stream = new BytesStreamOutput(); + request.search(searchRequest -> searchResponse, null); + request.writeTo(stream); + stream.flush(); + assertTrue(stream.size() > 0); + + // deserialize + var inStream = new BytesStreamInput(stream.bytes().toBytesRef().bytes); + var indexMock = mock(OpenSearchIndex.class); + var engine = mock(OpenSearchStorageEngine.class); + when(engine.getTable(any(), any())).thenReturn(indexMock); + var newRequest = new OpenSearchScrollRequest(inStream, engine); + assertEquals(request.getInitialSearchRequest(), newRequest.getInitialSearchRequest()); + assertEquals("", newRequest.getScrollId()); + } + + @Test + void setScrollId() { + request.setScrollId("test"); + assertEquals("test", request.getScrollId()); + } + + @Test + void includes() { + + assertIncludes(List.of(), searchSourceBuilder); + + searchSourceBuilder.fetchSource((String[])null, (String[])null); + assertIncludes(List.of(), searchSourceBuilder); + + searchSourceBuilder.fetchSource(new String[] {"test"}, null); + assertIncludes(List.of("test"), searchSourceBuilder); + + } + + void assertIncludes(List expected, SearchSourceBuilder sourceBuilder) { + assertEquals(expected, new OpenSearchScrollRequest( + INDEX_NAME, SCROLL_TIMEOUT, sourceBuilder, factory).getIncludes()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 2ff1de862b1..11694813cc9 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.opensearch.storage; import static org.hamcrest.MatcherAssert.assertThat; @@ -12,14 +11,13 @@ import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType; @@ -30,9 +28,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; import com.google.common.collect.ImmutableMap; -import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -45,34 +41,30 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.aggregation.AvgAggregator; -import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; -import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; -import org.opensearch.sql.opensearch.request.PagedRequestBuilder; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; -import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScan; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; -import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) class OpenSearchIndexTest { - private final String indexName = "test"; + public static final int QUERY_SIZE_LIMIT = 200; + public static final TimeValue SCROLL_TIMEOUT = new TimeValue(1); + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); @Mock private OpenSearchClient client; @@ -83,9 +75,6 @@ class OpenSearchIndexTest { @Mock private Settings settings; - @Mock - private Table table; - @Mock private IndexMapping mapping; @@ -93,30 +82,31 @@ class OpenSearchIndexTest { @BeforeEach void setUp() { - this.index = new OpenSearchIndex(client, settings, indexName); + this.index = new OpenSearchIndex(client, settings, "test"); } @Test void isExist() { - when(client.exists(indexName)).thenReturn(true); + when(client.exists("test")).thenReturn(true); assertTrue(index.exists()); } @Test void createIndex() { - Map mappings = ImmutableMap.of( + Map mappings = Map.of( "properties", - ImmutableMap.of( + Map.of( "name", "text", "age", "integer")); - doNothing().when(client).createIndex(indexName, mappings); + doNothing().when(client).createIndex("test", mappings); Map schema = new HashMap<>(); schema.put("name", OpenSearchTextType.of(Map.of("keyword", OpenSearchDataType.of(MappingType.Keyword)))); schema.put("age", INTEGER); index.create(schema); + verify(client).createIndex(any(), any()); } @Test @@ -137,7 +127,7 @@ void getFieldTypes() { .put("id2", MappingType.Short) .put("blob", MappingType.Binary) .build().entrySet().stream().collect(Collectors.toMap( - e -> e.getKey(), e -> OpenSearchDataType.of(e.getValue()) + Map.Entry::getKey, e -> OpenSearchDataType.of(e.getValue()) ))); when(client.getIndexMappings("test")).thenReturn(ImmutableMap.of("test", mapping)); @@ -208,64 +198,38 @@ void getReservedFieldTypes() { @Test void implementRelationOperatorOnly() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals(new OpenSearchIndexScan(client, settings, indexName, - maxResultWindow, exprValueFactory), index.implement(index.optimize(plan))); - } - - @Test - void implementPagedRelationOperatorOnly() { - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); - - LogicalPlan plan = index.createPagedScanBuilder(42); - Integer maxResultWindow = index.getMaxResultWindow(); - PagedRequestBuilder builder = new InitialPageRequestBuilder( - new OpenSearchRequest.IndexName(indexName), - maxResultWindow, mock(), exprValueFactory); - assertEquals(new OpenSearchPagedIndexScan(client, builder), index.implement(plan)); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + assertEquals(new OpenSearchIndexScan(client, + 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + index.implement(index.optimize(plan))); } @Test void implementRelationOperatorWithOptimization() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals(new OpenSearchIndexScan(client, settings, indexName, - maxResultWindow, exprValueFactory), index.implement(plan)); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + assertEquals(new OpenSearchIndexScan(client, 200, + requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), index.implement(plan)); } @Test void implementOtherLogicalOperators() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); ReferenceExpression dedupeField = ref("name", STRING); - Expression filterExpr = literal(ExprBooleanValue.of(true)); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); Map mappings = ImmutableMap.of(ref("name", STRING), ref("lastname", STRING)); Pair newEvalField = ImmutablePair.of(ref("name1", STRING), ref("name", STRING)); - Integer sortCount = 100; Pair sortField = ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("name1", STRING)); @@ -285,6 +249,7 @@ void implementOtherLogicalOperators() { include); Integer maxResultWindow = index.getMaxResultWindow(); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.dedupe( @@ -292,8 +257,9 @@ void implementOtherLogicalOperators() { PhysicalPlanDSL.eval( PhysicalPlanDSL.remove( PhysicalPlanDSL.rename( - new OpenSearchIndexScan(client, settings, indexName, - maxResultWindow, exprValueFactory), + new OpenSearchIndexScan(client, + QUERY_SIZE_LIMIT, requestBuilder.build(INDEX_NAME, maxResultWindow, + SCROLL_TIMEOUT)), mappings), exclude), newEvalField), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java new file mode 100644 index 00000000000..5a510fefec1 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +@ExtendWith(MockitoExtension.class) +class OpenSearchIndexScanAggregationBuilderTest { + @Mock + OpenSearchRequestBuilder requestBuilder; + @Mock + LogicalAggregation logicalAggregation; + OpenSearchIndexScanAggregationBuilder builder; + + @BeforeEach + void setup() { + builder = new OpenSearchIndexScanAggregationBuilder(requestBuilder, logicalAggregation); + } + + @Test + void pushDownFilter() { + assertFalse(builder.pushDownFilter(mock(LogicalFilter.class))); + } + + @Test + void pushDownSort() { + assertTrue(builder.pushDownSort(mock(LogicalSort.class))); + } + + @Test + void pushDownLimit() { + assertFalse(builder.pushDownLimit(mock(LogicalLimit.class))); + } + + @Test + void pushDownProject() { + assertFalse(builder.pushDownProject(mock(LogicalProject.class))); + } + + @Test + void pushDownHighlight() { + assertFalse(builder.pushDownHighlight(mock(LogicalHighlight.class))); + } + + @Test + void pushDownPageSize() { + assertFalse(builder.pushDownPageSize(mock(LogicalPaginate.class))); + } + + @Test + void pushDownNested() { + assertFalse(builder.pushDownNested(mock(LogicalNested.class))); + } + +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index bde940a939e..6bf9002a673 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -7,7 +7,8 @@ package org.opensearch.sql.opensearch.storage.scan; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.eq; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -24,6 +25,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; @@ -53,7 +55,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.SpanOrQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; @@ -65,7 +66,6 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.FunctionExpression; @@ -79,13 +79,14 @@ import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; -import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalNested; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.planner.optimizer.PushDownPageSize; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; import org.opensearch.sql.storage.Table; - +import org.opensearch.sql.storage.TableScanOperator; @ExtendWith(MockitoExtension.class) class OpenSearchIndexScanOptimizationTest { @@ -105,16 +106,20 @@ class OpenSearchIndexScanOptimizationTest { @BeforeEach void setUp() { - indexScanBuilder = new OpenSearchIndexScanBuilder(indexScan); + indexScanBuilder = new OpenSearchIndexScanBuilder(requestBuilder) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; when(table.createScanBuilder()).thenReturn(indexScanBuilder); - when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); } @Test void test_project_push_down() { assertEqualsAfterOptimization( project( - indexScanAggBuilder( + indexScanBuilder( withProjectPushedDown(DSL.ref("intV", INTEGER))), DSL.named("i", DSL.ref("intV", INTEGER)) ), @@ -336,6 +341,21 @@ void test_sort_push_down() { ); } + @Test + void test_page_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withPageSizePushDown(5)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + paginate(project( + relation("schema", table), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), 5 + )); + } + @Test void test_score_sort_push_down() { assertEqualsAfterOptimization( @@ -678,16 +698,28 @@ void project_literal_should_not_be_pushed_down() { private OpenSearchIndexScanBuilder indexScanBuilder(Runnable... verifyPushDownCalls) { this.verifyPushDownCalls = verifyPushDownCalls; - return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(indexScan)); + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(requestBuilder)) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; } private OpenSearchIndexScanBuilder indexScanAggBuilder(Runnable... verifyPushDownCalls) { this.verifyPushDownCalls = verifyPushDownCalls; - return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder(indexScan)); + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder( + requestBuilder, mock(LogicalAggregation.class))) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; } private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan actual) { - assertEquals(expected, optimize(actual)); + final var optimized = optimize(actual); + assertEquals(expected, optimized); // Trigger build to make sure all push down actually happened in scan builder indexScanBuilder.build(); @@ -759,6 +791,10 @@ private Runnable withTrackedScoresPushedDown(boolean trackScores) { return () -> verify(requestBuilder, times(1)).pushDownTrackedScore(trackScores); } + private Runnable withPageSizePushDown(int pageSize) { + return () -> verify(requestBuilder, times(1)).pushDownPageSize(pageSize); + } + private static AggregationAssertHelper.AggregationAssertHelperBuilder aggregate(String aggName) { var aggBuilder = new AggregationAssertHelper.AggregationAssertHelperBuilder(); aggBuilder.aggregateName = aggName; @@ -784,6 +820,7 @@ private static class AggregationAssertHelper { private LogicalPlan optimize(LogicalPlan plan) { LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(List.of( new CreateTableScanBuilder(), + new PushDownPageSize(), PUSH_DOWN_FILTER, PUSH_DOWN_AGGREGATION, PUSH_DOWN_SORT, diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java new file mode 100644 index 00000000000..67f0869d6eb --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.QUERY_SIZE; +import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.mockResponse; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class OpenSearchIndexScanPaginationTest { + + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); + public static final int MAX_RESULT_WINDOW = 3; + public static final TimeValue SCROLL_TIMEOUT = TimeValue.timeValueMinutes(4); + @Mock + private Settings settings; + + @BeforeEach + void setup() { + lenient().when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(QUERY_SIZE); + lenient().when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + } + + @Mock + private OpenSearchClient client; + + private final OpenSearchExprValueFactory exprValueFactory + = new OpenSearchExprValueFactory(Map.of( + "name", OpenSearchDataType.of(STRING), + "department", OpenSearchDataType.of(STRING))); + + @Test + void query_empty_result() { + mockResponse(client); + var builder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (var indexScan = new OpenSearchIndexScan(client, MAX_RESULT_WINDOW, + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + indexScan.open(); + assertFalse(indexScan.hasNext()); + } + verify(client).cleanup(any()); + } + + @Test + void explain_not_implemented() { + assertThrows(Throwable.class, () -> mock(OpenSearchIndexScan.class, + withSettings().defaultAnswer(CALLS_REAL_METHODS)).explain()); + } + + @Test + @SneakyThrows + void dont_serialize_if_no_cursor() { + OpenSearchRequestBuilder builder = mock(); + OpenSearchRequest request = mock(); + OpenSearchResponse response = mock(); + when(builder.build(any(), anyInt(), any())).thenReturn(request); + when(client.search(any())).thenReturn(response); + try (var indexScan + = new OpenSearchIndexScan(client, MAX_RESULT_WINDOW, + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + indexScan.open(); + + when(request.hasAnotherBatch()).thenReturn(false); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + assertThrows(NoCursorException.class, () -> objectOutput.writeObject(indexScan)); + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java index c788e78f1a4..e9747906299 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.lenient; @@ -19,9 +20,12 @@ import static org.opensearch.search.sort.SortOrder.ASC; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; @@ -36,47 +40,105 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchIndexScanTest { + public static final int QUERY_SIZE = 200; + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("employees"); + public static final int MAX_RESULT_WINDOW = 10000; + public static final TimeValue CURSOR_KEEP_ALIVE = TimeValue.timeValueMinutes(1); @Mock private OpenSearchClient client; - @Mock - private Settings settings; - - private OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( + private final OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( Map.of("name", OpenSearchDataType.of(STRING), "department", OpenSearchDataType.of(STRING))); @BeforeEach void setup() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) - .thenReturn(TimeValue.timeValueMinutes(1)); + } + + @Test + void explain() { + var request = mock(OpenSearchRequest.class); + when(request.toString()).thenReturn("explain works!"); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + assertEquals("explain works!", indexScan.explain()); + } + } + + @Test + @SneakyThrows + void throws_no_cursor_exception() { + var request = mock(OpenSearchRequest.class); + when(request.hasAnotherBatch()).thenReturn(false); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request); + var byteStream = new ByteArrayOutputStream(); + var objectStream = new ObjectOutputStream(byteStream)) { + assertThrows(NoCursorException.class, () -> objectStream.writeObject(indexScan)); + } + } + + @Test + @SneakyThrows + void serialize() { + var searchSourceBuilder = new SearchSourceBuilder().size(4); + + var factory = mock(OpenSearchExprValueFactory.class); + var engine = mock(OpenSearchStorageEngine.class); + var index = mock(OpenSearchIndex.class); + when(engine.getClient()).thenReturn(client); + when(engine.getTable(any(), any())).thenReturn(index); + var request = new OpenSearchScrollRequest( + INDEX_NAME, CURSOR_KEEP_ALIVE, searchSourceBuilder, factory); + request.setScrollId("valid-id"); + + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + var planSerializer = new PlanSerializer(engine); + var cursor = planSerializer.convertToCursor(indexScan); + var newPlan = planSerializer.convertToPlan(cursor.toString()); + assertEquals(indexScan, newPlan); + } + + } + + @Test + void plan_for_serialization() { + var request = mock(OpenSearchRequest.class); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + assertEquals(indexScan, indexScan.getPlanForSerialization()); + } } @Test void query_empty_result() { mockResponse(client); - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "test", 3, exprValueFactory)) { + final var name = new OpenSearchRequest.IndexName("test"); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + QUERY_SIZE, requestBuilder.build(name, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( () -> assertFalse(indexScan.hasNext()), @@ -93,8 +155,9 @@ void query_all_results_with_query() { employee(2, "Smith", "HR"), employee(3, "Allen", "IT")}); - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "employees", 10, exprValueFactory)) { + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( @@ -114,14 +177,18 @@ void query_all_results_with_query() { verify(client).cleanup(any()); } + static final OpenSearchRequest.IndexName EMPLOYEES_INDEX + = new OpenSearchRequest.IndexName("employees"); + @Test void query_all_results_with_scroll() { mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT")}); - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "employees", 10, exprValueFactory)) { + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( @@ -149,9 +216,10 @@ void query_some_results_with_query() { employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "employees", 10, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); + final int limit = 3; + OpenSearchRequestBuilder builder = new OpenSearchRequestBuilder(0, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + limit, builder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( @@ -173,13 +241,10 @@ void query_some_results_with_query() { @Test void query_some_results_with_scroll() { - mockResponse(client, - new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, - new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "employees", 2, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); + mockTwoPageResponse(client); + final var requestuilder = new OpenSearchRequestBuilder(10, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 3, requestuilder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( @@ -199,6 +264,12 @@ void query_some_results_with_scroll() { verify(client).cleanup(any()); } + static void mockTwoPageResponse(OpenSearchClient client) { + mockResponse(client, + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); + } + @Test void query_results_limited_by_query_size() { mockResponse(client, new ExprValue[]{ @@ -206,10 +277,11 @@ void query_results_limited_by_query_size() { employee(2, "Smith", "HR"), employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(2); - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "employees", 10, exprValueFactory)) { + final int defaultQuerySize = 2; + final var requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + defaultQuerySize, requestBuilder.build(INDEX_NAME, QUERY_SIZE, CURSOR_KEEP_ALIVE))) { indexScan.open(); assertAll( @@ -270,73 +342,65 @@ void push_down_highlight_with_arguments() { highlightBuilder); } - @Test - void push_down_highlight_with_repeating_fields() { - mockResponse(client, - new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, - new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - - try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, - "test", 2, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); - indexScan.open(); - Map args = new HashMap<>(); - indexScan.getRequestBuilder().pushDownHighlight("name", args); - indexScan.getRequestBuilder().pushDownHighlight("name", args); - } catch (SemanticCheckException e) { - assertTrue(e.getClass().equals(SemanticCheckException.class)); - } - verify(client).cleanup(any()); - } - private PushDownAssertion assertThat() { - return new PushDownAssertion(client, exprValueFactory, settings); + return new PushDownAssertion(client, exprValueFactory); } private static class PushDownAssertion { private final OpenSearchClient client; - private final OpenSearchIndexScan indexScan; + private final OpenSearchRequestBuilder requestBuilder; private final OpenSearchResponse response; private final OpenSearchExprValueFactory factory; public PushDownAssertion(OpenSearchClient client, - OpenSearchExprValueFactory valueFactory, - Settings settings) { + OpenSearchExprValueFactory valueFactory) { this.client = client; - this.indexScan = new OpenSearchIndexScan(client, settings, - "test", 10000, valueFactory); + this.requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, valueFactory); + this.response = mock(OpenSearchResponse.class); this.factory = valueFactory; when(response.isEmpty()).thenReturn(true); } PushDownAssertion pushDown(QueryBuilder query) { - indexScan.getRequestBuilder().pushDownFilter(query); + requestBuilder.pushDownFilter(query); return this; } PushDownAssertion pushDownHighlight(String query, Map arguments) { - indexScan.getRequestBuilder().pushDownHighlight(query, arguments); + requestBuilder.pushDownHighlight(query, arguments); return this; } PushDownAssertion shouldQueryHighlight(QueryBuilder query, HighlightBuilder highlight) { - OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); - request.getSourceBuilder() + var sourceBuilder = new SearchSourceBuilder() + .from(0) + .timeout(CURSOR_KEEP_ALIVE) .query(query) + .size(QUERY_SIZE) .highlighter(highlight) .sort(DOC_FIELD_NAME, ASC); + OpenSearchRequest request = + new OpenSearchQueryRequest(EMPLOYEES_INDEX, sourceBuilder, factory); + when(client.search(request)).thenReturn(response); + var indexScan = new OpenSearchIndexScan(client, + QUERY_SIZE, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); indexScan.open(); return this; } PushDownAssertion shouldQuery(QueryBuilder expected) { - OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); - request.getSourceBuilder() - .query(expected) - .sort(DOC_FIELD_NAME, ASC); + var builder = new SearchSourceBuilder() + .from(0) + .query(expected) + .size(QUERY_SIZE) + .timeout(CURSOR_KEEP_ALIVE) + .sort(DOC_FIELD_NAME, ASC); + OpenSearchRequest request = new OpenSearchQueryRequest(EMPLOYEES_INDEX, builder, factory); when(client.search(request)).thenReturn(response); + var indexScan = new OpenSearchIndexScan(client, + 10000, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); indexScan.open(); return this; } @@ -356,7 +420,6 @@ public OpenSearchResponse answer(InvocationOnMock invocation) { when(response.isEmpty()).thenReturn(false); ExprValue[] searchHit = searchHitBatches[batchNum]; when(response.iterator()).thenReturn(Arrays.asList(searchHit).iterator()); - // used in OpenSearchPagedIndexScanTest lenient().when(response.getTotalHits()) .thenReturn((long) searchHitBatches[batchNum].length); } else { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java deleted file mode 100644 index cd941540126..00000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.CALLS_REAL_METHODS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.employee; -import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.mockResponse; - -import com.google.common.collect.ImmutableMap; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.util.Map; -import lombok.SneakyThrows; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.NoCursorException; -import org.opensearch.sql.executor.pagination.PlanSerializer; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.request.ContinuePageRequestBuilder; -import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; -import org.opensearch.sql.opensearch.request.OpenSearchRequest; -import org.opensearch.sql.opensearch.request.PagedRequestBuilder; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; -import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; - -@ExtendWith(MockitoExtension.class) -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class OpenSearchPagedIndexScanTest { - @Mock - private OpenSearchClient client; - - private final OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( - ImmutableMap.of( - "name", OpenSearchDataType.of(STRING), - "department", OpenSearchDataType.of(STRING))); - - @Test - void query_empty_result() { - mockResponse(client); - InitialPageRequestBuilder builder = new InitialPageRequestBuilder( - new OpenSearchRequest.IndexName("test"), 3, mock(), exprValueFactory); - try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { - indexScan.open(); - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void query_all_results_initial_scroll_request() { - mockResponse(client, new ExprValue[]{ - employee(1, "John", "IT"), - employee(2, "Smith", "HR"), - employee(3, "Allen", "IT")}); - - PagedRequestBuilder builder = new InitialPageRequestBuilder( - new OpenSearchRequest.IndexName("test"), 3, mock(), exprValueFactory); - try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { - indexScan.open(); - - assertAll( - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - - () -> assertFalse(indexScan.hasNext()), - () -> assertEquals(3, indexScan.getTotalHits()) - ); - } - verify(client).cleanup(any()); - - builder = new ContinuePageRequestBuilder( - new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); - try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { - indexScan.open(); - - assertFalse(indexScan.hasNext()); - } - verify(client, times(2)).cleanup(any()); - } - - @Test - void query_all_results_continuation_scroll_request() { - mockResponse(client, new ExprValue[]{ - employee(1, "John", "IT"), - employee(2, "Smith", "HR"), - employee(3, "Allen", "IT")}); - - ContinuePageRequestBuilder builder = new ContinuePageRequestBuilder( - new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); - try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { - indexScan.open(); - - assertAll( - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - - () -> assertTrue(indexScan.hasNext()), - () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - - () -> assertFalse(indexScan.hasNext()), - () -> assertEquals(3, indexScan.getTotalHits()) - ); - } - verify(client).cleanup(any()); - - builder = new ContinuePageRequestBuilder( - new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); - try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { - indexScan.open(); - - assertFalse(indexScan.hasNext()); - } - verify(client, times(2)).cleanup(any()); - } - - @Test - void explain_not_implemented() { - assertThrows(Throwable.class, () -> mock(OpenSearchPagedIndexScan.class, - withSettings().defaultAnswer(CALLS_REAL_METHODS)).explain()); - } - - @Test - @SneakyThrows - void serialization() { - PagedRequestBuilder builder = mock(); - OpenSearchRequest request = mock(); - OpenSearchResponse response = mock(); - when(request.toCursor()).thenReturn("cu-cursor"); - when(builder.build()).thenReturn(request); - var indexName = new OpenSearchRequest.IndexName("index"); - when(builder.getIndexName()).thenReturn(indexName); - when(client.search(any())).thenReturn(response); - OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder); - indexScan.open(); - - ByteArrayOutputStream output = new ByteArrayOutputStream(); - ObjectOutputStream objectOutput = new ObjectOutputStream(output); - objectOutput.writeObject(indexScan); - objectOutput.flush(); - - when(client.getIndexMappings(any())).thenReturn(Map.of()); - OpenSearchStorageEngine engine = mock(); - when(engine.getClient()).thenReturn(client); - when(engine.getSettings()).thenReturn(mock()); - ObjectInputStream objectInput = new PlanSerializer(engine) - .getCursorDeserializationStream(new ByteArrayInputStream(output.toByteArray())); - var roundTripScan = (OpenSearchPagedIndexScan) objectInput.readObject(); - roundTripScan.open(); - - // indexScan's request could be a OpenSearchScrollRequest or a ContinuePageRequest, but - // roundTripScan's request is always a ContinuePageRequest - // Thus, we can't compare those scans - //assertEquals(indexScan, roundTripScan); - // But we can validate that index name and scroll was serialized-deserialized correctly - assertEquals(indexName, roundTripScan.getRequestBuilder().getIndexName()); - assertTrue(roundTripScan.getRequestBuilder() instanceof ContinuePageRequestBuilder); - assertEquals("cu-cursor", - ((ContinuePageRequestBuilder) roundTripScan.getRequestBuilder()).getScrollId()); - } - - @Test - @SneakyThrows - void dont_serialize_if_no_cursor() { - PagedRequestBuilder builder = mock(); - OpenSearchRequest request = mock(); - OpenSearchResponse response = mock(); - when(builder.build()).thenReturn(request); - when(client.search(any())).thenReturn(response); - OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder); - indexScan.open(); - - when(request.toCursor()).thenReturn(null, ""); - for (int i = 0; i < 2; i++) { - assertThrows(NoCursorException.class, () -> { - ByteArrayOutputStream output = new ByteArrayOutputStream(); - ObjectOutputStream objectOutput = new ObjectOutputStream(output); - objectOutput.writeObject(indexScan); - objectOutput.flush(); - }); - } - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java new file mode 100644 index 00000000000..0b0568a6b7d --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java @@ -0,0 +1,42 @@ +package org.opensearch.sql.opensearch.storage.scan; + + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +@ExtendWith(MockitoExtension.class) +class PushDownQueryBuilderTest { + @Test + void default_implementations() { + var sample = new PushDownQueryBuilder() { + @Override + public OpenSearchRequestBuilder build() { + return null; + } + }; + assertAll( + () -> assertFalse(sample.pushDownFilter(mock(LogicalFilter.class))), + () -> assertFalse(sample.pushDownProject(mock(LogicalProject.class))), + () -> assertFalse(sample.pushDownHighlight(mock(LogicalHighlight.class))), + () -> assertFalse(sample.pushDownSort(mock(LogicalSort.class))), + () -> assertFalse(sample.pushDownNested(mock(LogicalNested.class))), + () -> assertFalse(sample.pushDownLimit(mock(LogicalLimit.class))), + () -> assertFalse(sample.pushDownPageSize(mock(LogicalPaginate.class))) + + ); + } + +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java index b80cb3faab5..f301a242fb4 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java @@ -75,7 +75,7 @@ public ExecutionProtector protector(ResourceMonitor resourceMonitor) { } @Provides - public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { + public PlanSerializer planSerializer(StorageEngine storageEngine) { return new PlanSerializer(storageEngine); } @@ -100,14 +100,13 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl */ @Provides public QueryPlanFactory queryPlanFactory(DataSourceService dataSourceService, - ExecutionEngine executionEngine, - PlanSerializer planSerializer) { + ExecutionEngine executionEngine) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); QueryService queryService = new QueryService( analyzer, executionEngine, planner); - return new QueryPlanFactory(queryService, planSerializer); + return new QueryPlanFactory(queryService); } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java index f91ac7222f5..40a7a85f78b 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java @@ -90,7 +90,7 @@ private AbstractPlan plan( QueryContext.getRequestId(), anonymizer.anonymizeStatement(statement)); - return queryExecutionFactory.createContinuePaginatedPlan( + return queryExecutionFactory.create( statement, queryListener, explainListener); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java index 117aca50bfc..74e5b0f82ea 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java @@ -27,7 +27,6 @@ import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; import org.opensearch.sql.executor.pagination.Cursor; -import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; import org.opensearch.sql.ppl.domain.PPLQueryRequest; @@ -48,9 +47,6 @@ public class PPLServiceTest { @Mock private ExecutionEngine.Schema schema; - @Mock - private PlanSerializer planSerializer; - /** * Setup the test context. */ @@ -59,7 +55,7 @@ public void setUp() { queryManager = DefaultQueryManager.defaultQueryManager(); pplService = new PPLService(new PPLSyntaxParser(), queryManager, - new QueryPlanFactory(queryService, planSerializer)); + new QueryPlanFactory(queryService)); } @After diff --git a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java index 4ecf9e699be..889f80223f4 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java +++ b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java @@ -65,10 +65,15 @@ private AbstractPlan plan( SQLQueryRequest request, Optional> queryListener, Optional> explainListener) { + boolean isExplainRequest = request.isExplainRequest(); if (request.getCursor().isPresent()) { // Handle v2 cursor here -- legacy cursor was handled earlier. - return queryExecutionFactory.createContinuePaginatedPlan(request.getCursor().get(), - request.isExplainRequest(), queryListener.orElse(null), explainListener.orElse(null)); + if (isExplainRequest) { + throw new UnsupportedOperationException("Explain of a paged query continuation " + + "is not supported. Use `explain` for the initial query request."); + } + return queryExecutionFactory.create(request.getCursor().get(), + isExplainRequest, queryListener.orElse(null), explainListener.orElse(null)); } else { // 1.Parse query and convert parse tree (CST) to abstract syntax tree (AST) ParseTree cst = parser.parse(request.getQuery()); @@ -77,11 +82,11 @@ private AbstractPlan plan( new AstStatementBuilder( new AstBuilder(request.getQuery()), AstStatementBuilder.StatementBuilderContext.builder() - .isExplain(request.isExplainRequest()) + .isExplain(isExplainRequest) .fetchSize(request.getFetchSize()) .build())); - return queryExecutionFactory.createContinuePaginatedPlan( + return queryExecutionFactory.create( statement, queryListener, explainListener); } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java b/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java index 39c27c5e069..f34c95e121f 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java @@ -6,8 +6,8 @@ package org.opensearch.sql.sql; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -30,7 +30,6 @@ import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; -import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; import org.opensearch.sql.sql.domain.SQLQueryRequest; @@ -49,14 +48,11 @@ class SQLServiceTest { @Mock private QueryService queryService; - @Mock - private PlanSerializer planSerializer; - @BeforeEach public void setUp() { queryManager = DefaultQueryManager.defaultQueryManager(); sqlService = new SQLService(new SQLSyntaxParser(), queryManager, - new QueryPlanFactory(queryService, planSerializer)); + new QueryPlanFactory(queryService)); } @AfterEach @@ -149,8 +145,8 @@ public void onResponse(ExplainResponse response) { @Override public void onFailure(Exception e) { - assertTrue(e.getMessage() - .contains("`explain` request for cursor requests is not supported.")); + assertEquals("Explain of a paged query continuation is not supported." + + " Use `explain` for the initial query request.", e.getMessage()); } }); }