diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/AggregateUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/AggregateUtils.java index 99321a426fc..af30c47d2e3 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/AggregateUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/AggregateUtils.java @@ -47,6 +47,10 @@ static RelBuilder.AggCall translate( // case STDDEV: // return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV, // field); + case VARSAMP: + return context.relBuilder.aggregateCall(SqlStdOperatorTable.VAR_SAMP, field); + case VARPOP: + return context.relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, field); case STDDEV_POP: return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_POP, field); case STDDEV_SAMP: diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java index e17e0a1cf9b..435a192ae51 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java @@ -178,6 +178,21 @@ public void testApproxCountDistinct() { "source=%s | stats distinct_count_approx(state) by gender", TEST_INDEX_BANK)); } + @Test + public void testVarSampVarPop() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats var_samp(balance) as vs, var_pop(balance) as vp by gender", + TEST_INDEX_BANK)); + verifySchema( + actual, schema("gender", "string"), schema("vs", "double"), schema("vp", "double")); + verifyDataRows( + actual, + rows("F", 58127404, 38751602.666666664), + rows("M", 261699024.91666666, 196274268.6875)); + } + @Test public void testStddevSampStddevPop() { JSONObject actual = diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java index 6bddd8f61ce..60fb680f116 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java @@ -10,17 +10,17 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.convert.ConverterRule; -import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalTableScan; -import org.opensearch.sql.opensearch.storage.scan.CalciteOpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteEnumerableIndexScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; -/** Rule to convert a {@link CalciteLogicalTableScan} to a {@link CalciteOpenSearchIndexScan}. */ +/** Rule to convert a {@link CalciteLogicalIndexScan} to a {@link CalciteEnumerableIndexScan}. */ public class EnumerableIndexScanRule extends ConverterRule { /** Default configuration. 
*/ public static final Config DEFAULT_CONFIG = Config.INSTANCE .as(Config.class) .withConversion( - CalciteLogicalTableScan.class, + CalciteLogicalIndexScan.class, s -> s.getOsIndex() != null, Convention.NONE, EnumerableConvention.INSTANCE, @@ -34,13 +34,19 @@ protected EnumerableIndexScanRule(Config config) { @Override public boolean matches(RelOptRuleCall call) { - CalciteLogicalTableScan scan = call.rel(0); + CalciteLogicalIndexScan scan = call.rel(0); return scan.getVariablesSet().isEmpty(); } @Override public RelNode convert(RelNode rel) { - final CalciteLogicalTableScan scan = (CalciteLogicalTableScan) rel; - return new CalciteOpenSearchIndexScan(scan.getCluster(), scan.getTable(), scan.getOsIndex()); + final CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan) rel; + return new CalciteEnumerableIndexScan( + scan.getCluster(), + scan.getHints(), + scan.getTable(), + scan.getOsIndex(), + scan.getSchema(), + scan.getPushDownContext()); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java new file mode 100644 index 00000000000..9298f1c9add --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.opensearch.planner.physical; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.immutables.value.Value; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; + +/** Planner rule that push a {@link LogicalAggregate} down to {@link CalciteLogicalIndexScan} */ +@Value.Enclosing +public class OpenSearchAggregateIndexScanRule + extends RelRule { + + /** 
Creates an OpenSearchAggregateIndexScanRule. */ + protected OpenSearchAggregateIndexScanRule(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + if (call.rels.length == 2) { + // the ordinary variant + final LogicalAggregate aggregate = call.rel(0); + final CalciteLogicalIndexScan scan = call.rel(1); + apply(call, aggregate, scan); + } else { + throw new AssertionError( + String.format( + "The length of rels should be %s but got %s", + this.operands.size(), call.rels.length)); + } + } + + protected void apply( + RelOptRuleCall call, LogicalAggregate aggregate, CalciteLogicalIndexScan scan) { + CalciteLogicalIndexScan newScan = scan.pushDownAggregate(aggregate); + if (newScan != null) { + call.transformTo(newScan); + } + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + /** Config that matches LogicalAggregate on CalciteLogicalIndexScan. */ + Config DEFAULT = + ImmutableOpenSearchAggregateIndexScanRule.Config.builder() + .build() + .withOperandSupplier( + b0 -> + b0.operand(LogicalAggregate.class) + .oneInput( + b1 -> + b1.operand(CalciteLogicalIndexScan.class) + .predicate(OpenSearchIndexScanRule::test) + .noInputs())); + + @Override + default OpenSearchAggregateIndexScanRule toRule() { + return new OpenSearchAggregateIndexScanRule(this); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java index 621bfd8c6fd..f0f1777e9f5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java @@ -5,14 +5,13 @@ package org.opensearch.sql.opensearch.planner.physical; import org.apache.calcite.plan.RelOptRuleCall; -import 
org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.logical.LogicalFilter; import org.immutables.value.Value; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; -import org.opensearch.sql.opensearch.storage.scan.CalciteOpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; -/** Planner rule that push a {@link Filter} down to {@link CalciteOpenSearchIndexScan} */ +/** Planner rule that push a {@link LogicalFilter} down to {@link CalciteLogicalIndexScan} */ @Value.Enclosing public class OpenSearchFilterIndexScanRule extends RelRule { @@ -21,17 +20,12 @@ protected OpenSearchFilterIndexScanRule(Config config) { super(config); } - protected static boolean test(CalciteOpenSearchIndexScan scan) { - final RelOptTable table = scan.getTable(); - return table.unwrap(OpenSearchIndex.class) != null; - } - @Override public void onMatch(RelOptRuleCall call) { if (call.rels.length == 2) { // the ordinary variant - final Filter filter = call.rel(0); - final CalciteOpenSearchIndexScan scan = call.rel(1); + final LogicalFilter filter = call.rel(0); + final CalciteLogicalIndexScan scan = call.rel(1); apply(call, filter, scan); } else { throw new AssertionError( @@ -41,8 +35,8 @@ public void onMatch(RelOptRuleCall call) { } } - protected void apply(RelOptRuleCall call, Filter filter, CalciteOpenSearchIndexScan scan) { - CalciteOpenSearchIndexScan newScan = scan.pushDownFilter(filter); + protected void apply(RelOptRuleCall call, Filter filter, CalciteLogicalIndexScan scan) { + CalciteLogicalIndexScan newScan = scan.pushDownFilter(filter); if (newScan != null) { call.transformTo(newScan); } @@ -51,17 +45,17 @@ protected void apply(RelOptRuleCall call, Filter filter, CalciteOpenSearchIndexS /** Rule configuration. 
*/ @Value.Immutable public interface Config extends RelRule.Config { - /** Config that matches Filter on CalciteOpenSearchIndexScan. */ + /** Config that matches Filter on CalciteLogicalIndexScan. */ Config DEFAULT = ImmutableOpenSearchFilterIndexScanRule.Config.builder() .build() .withOperandSupplier( b0 -> - b0.operand(Filter.class) + b0.operand(LogicalFilter.class) .oneInput( b1 -> - b1.operand(CalciteOpenSearchIndexScan.class) - .predicate(OpenSearchFilterIndexScanRule::test) + b1.operand(CalciteLogicalIndexScan.class) + .predicate(OpenSearchIndexScanRule::test) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index a1ee5b787db..cfc03870eb5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -14,9 +14,11 @@ public class OpenSearchIndexRules { OpenSearchProjectIndexScanRule.Config.DEFAULT.toRule(); private static final OpenSearchFilterIndexScanRule FILTER_INDEX_SCAN = OpenSearchFilterIndexScanRule.Config.DEFAULT.toRule(); + private static final OpenSearchAggregateIndexScanRule AGGREGATE_INDEX_SCAN = + OpenSearchAggregateIndexScanRule.Config.DEFAULT.toRule(); public static final List OPEN_SEARCH_INDEX_SCAN_RULES = - ImmutableList.of(PROJECT_INDEX_SCAN, FILTER_INDEX_SCAN); + ImmutableList.of(PROJECT_INDEX_SCAN, FILTER_INDEX_SCAN, AGGREGATE_INDEX_SCAN); // prevent instantiation private OpenSearchIndexRules() {} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java new file mode 100644 index 00000000000..78ab8607764 --- /dev/null +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java @@ -0,0 +1,16 @@ +package org.opensearch.sql.opensearch.planner.physical; + +import org.apache.calcite.plan.RelOptTable; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; + +public interface OpenSearchIndexScanRule { + + // CalciteLogicalIndexScan doesn't allow push-down anymore (except Sort under some strict + // condition) after Aggregate push-down. + static boolean test(CalciteLogicalIndexScan scan) { + if (scan.getPushDownContext().isAggregatePushed()) return false; + final RelOptTable table = scan.getTable(); + return table.unwrap(OpenSearchIndex.class) != null; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchProjectIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchProjectIndexScanRule.java index e865a9eb024..80b1b76244d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchProjectIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchProjectIndexScanRule.java @@ -8,10 +8,10 @@ import java.util.ArrayList; import java.util.List; +import org.apache.calcite.adapter.enumerable.EnumerableProject; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelRule; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; @@ -20,9 +20,9 @@ import org.apache.calcite.util.mapping.Mappings; import org.immutables.value.Value; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; -import org.opensearch.sql.opensearch.storage.scan.CalciteOpenSearchIndexScan; +import 
org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; -/** Planner rule that push a {@link Project} down to {@link CalciteOpenSearchIndexScan} */ +/** Planner rule that push a {@link EnumerableProject} down to {@link CalciteLogicalIndexScan} */ @Value.Enclosing public class OpenSearchProjectIndexScanRule extends RelRule { @@ -31,17 +31,12 @@ protected OpenSearchProjectIndexScanRule(Config config) { super(config); } - protected static boolean test(CalciteOpenSearchIndexScan scan) { - final RelOptTable table = scan.getTable(); - return table.unwrap(OpenSearchIndex.class) != null; - } - @Override public void onMatch(RelOptRuleCall call) { if (call.rels.length == 2) { // the ordinary variant - final Project project = call.rel(0); - final CalciteOpenSearchIndexScan scan = call.rel(1); + final EnumerableProject project = call.rel(0); + final CalciteLogicalIndexScan scan = call.rel(1); apply(call, project, scan); } else { throw new AssertionError( @@ -51,10 +46,13 @@ public void onMatch(RelOptRuleCall call) { } } - protected void apply(RelOptRuleCall call, Project project, CalciteOpenSearchIndexScan scan) { + protected void apply( + RelOptRuleCall call, EnumerableProject project, CalciteLogicalIndexScan scan) { final RelOptTable table = scan.getTable(); requireNonNull(table.unwrap(OpenSearchIndex.class)); + // TODO: support script pushdown for project instead of only reference + // https://github.com/opensearch-project/sql/issues/3387 final List selectedColumns = new ArrayList<>(); final RexVisitorImpl visitor = new RexVisitorImpl(true) { @@ -70,7 +68,7 @@ public Void visitInputRef(RexInputRef inputRef) { // Only do push down when an actual projection happens if (!selectedColumns.isEmpty() && selectedColumns.size() != scan.getRowType().getFieldCount()) { Mapping mapping = Mappings.target(selectedColumns, scan.getRowType().getFieldCount()); - CalciteOpenSearchIndexScan newScan = scan.pushDownProject(selectedColumns); + CalciteLogicalIndexScan newScan = 
scan.pushDownProject(selectedColumns); final List newProjectRexNodes = RexUtil.apply(mapping, project.getProjects()); if (RexUtil.isIdentity(newProjectRexNodes, newScan.getRowType())) { @@ -90,11 +88,11 @@ public interface Config extends RelRule.Config { .build() .withOperandSupplier( b0 -> - b0.operand(Project.class) + b0.operand(EnumerableProject.class) .oneInput( b1 -> - b1.operand(CalciteOpenSearchIndexScan.class) - .predicate(OpenSearchProjectIndexScanRule::test) + b1.operand(CalciteLogicalIndexScan.class) + .predicate(OpenSearchIndexScanRule::test) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java new file mode 100644 index 00000000000..f370c206af4 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -0,0 +1,269 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.sql.opensearch.request; + +import static java.util.Objects.requireNonNull; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_INDEX; + +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.sql.SqlKind; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactories.Builder; +import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.missing.MissingOrder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; +import org.opensearch.search.aggregations.support.ValueType; +import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.request.PredicateAnalyzer.NamedFieldExpression; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; +import 
org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; + +/** + * Aggregate analyzer. Convert aggregate to AggregationBuilder {@link AggregationBuilder} and its + * related Parser {@link OpenSearchAggregationResponseParser}. + */ +public class AggregateAnalyzer { + + /** How many composite buckets should be returned. */ + public static final int AGGREGATION_BUCKET_SIZE = 1000; + + /** metadata field used when there is no argument. Only apply to COUNT. */ + private static final String METADATA_FIELD = "_index"; + + /** Internal exception. */ + @SuppressWarnings("serial") + public static final class AggregateAnalyzerException extends RuntimeException { + + AggregateAnalyzerException(String message) { + super(message); + } + + AggregateAnalyzerException(Throwable cause) { + super(cause); + } + } + + /** + * Exception that is thrown when a {@link Aggregate} cannot be processed (or converted into an + * OpenSearch aggregate query). + */ + public static class ExpressionNotAnalyzableException extends Exception { + ExpressionNotAnalyzableException(String message, Throwable cause) { + super(message, cause); + } + } + + private AggregateAnalyzer() {} + + // TODO: should we support filter aggregation? For PPL, we don't have filter in stats command + // TODO: support script pushdown for aggregation. 
Calcite doesn't support expression in its AggregateCall + // or GroupSet + // https://github.com/opensearch-project/sql/issues/3386 + // + public static Pair, OpenSearchAggregationResponseParser> analyze( + Aggregate aggregate, + List schema, + Map typeMapping, + List outputFields) + throws ExpressionNotAnalyzableException { + requireNonNull(aggregate, "aggregate"); + try { + List groupList = aggregate.getGroupSet().asList(); + FieldExpressionCreator fieldExpressionCreator = + fieldIndex -> new NamedFieldExpression(fieldIndex, schema, typeMapping); + // Process all aggregate calls + Pair> builderAndParser = + processAggregateCalls( + groupList.size(), aggregate.getAggCallList(), fieldExpressionCreator, outputFields); + Builder metricBuilder = builderAndParser.getLeft(); + List metricParserList = builderAndParser.getRight(); + + if (aggregate.getGroupSet().isEmpty()) { + return Pair.of( + ImmutableList.copyOf(metricBuilder.getAggregatorFactories()), + new NoBucketAggregationParser(metricParserList)); + } else { + List> buckets = + createCompositeBuckets(groupList, fieldExpressionCreator); + return Pair.of( + Collections.singletonList( + AggregationBuilders.composite("composite_buckets", buckets) + .subAggregations(metricBuilder) + .size(AGGREGATION_BUCKET_SIZE)), + new CompositeAggregationParser(metricParserList)); + } + } catch (Throwable e) { + Throwables.throwIfInstanceOf(e, UnsupportedOperationException.class); + throw new ExpressionNotAnalyzableException("Can't convert " + aggregate, e); + } + } + + private static Pair> processAggregateCalls( + int groupOffset, + List aggCalls, + FieldExpressionCreator fieldExpressionCreator, + List outputFields) { + Builder metricBuilder = new AggregatorFactories.Builder(); + List metricParserList = new ArrayList<>(); + + for (int i = 0; i < aggCalls.size(); i++) { + AggregateCall aggCall = aggCalls.get(i); + String argStr = + aggCall.getAggregation().kind == SqlKind.COUNT && aggCall.getArgList().isEmpty() + ? 
METADATA_FIELD_INDEX + : fieldExpressionCreator + .create(aggCall.getArgList().getFirst()) + .getReferenceForTermQuery(); + String aggField = outputFields.get(groupOffset + i); + + Pair, MetricParser> builderAndParser = + createAggregationBuilderAndParser(aggCall, argStr, aggField); + metricBuilder.addAggregator(builderAndParser.getLeft()); + metricParserList.add(builderAndParser.getRight()); + } + return Pair.of(metricBuilder, metricParserList); + } + + private interface FieldExpressionCreator { + NamedFieldExpression create(int fieldIndex); + } + + private static Pair, MetricParser> + createAggregationBuilderAndParser(AggregateCall aggCall, String argStr, String aggField) { + if (aggCall.isDistinct()) { + return createDistinctAggregation(aggCall, argStr, aggField); + } else { + return createRegularAggregation(aggCall, argStr, aggField); + } + } + + private static Pair, MetricParser> createDistinctAggregation( + AggregateCall aggCall, String argStr, String aggField) { + + return switch (aggCall.getAggregation().kind) { + case COUNT -> Pair.of( + AggregationBuilders.cardinality(aggField).field(argStr), new SingleValueParser(aggField)); + default -> throw new AggregateAnalyzer.AggregateAnalyzerException( + String.format("unsupported distinct aggregator %s", aggCall.getAggregation())); + }; + } + + private static Pair, MetricParser> createRegularAggregation( + AggregateCall aggCall, String argStr, String aggField) { + + return switch (aggCall.getAggregation().kind) { + case AVG -> Pair.of( + AggregationBuilders.avg(aggField).field(argStr), new SingleValueParser(aggField)); + case SUM -> Pair.of( + AggregationBuilders.sum(aggField).field(argStr), new SingleValueParser(aggField)); + case COUNT -> Pair.of( + AggregationBuilders.count(aggField).field(argStr), new SingleValueParser(aggField)); + case MIN -> Pair.of( + AggregationBuilders.min(aggField).field(argStr), new SingleValueParser(aggField)); + case MAX -> Pair.of( + AggregationBuilders.max(aggField).field(argStr), 
new SingleValueParser(aggField)); + case VAR_SAMP -> Pair.of( + AggregationBuilders.extendedStats(aggField).field(argStr), + new StatsParser(ExtendedStats::getVarianceSampling, aggField)); + case VAR_POP -> Pair.of( + AggregationBuilders.extendedStats(aggField).field(argStr), + new StatsParser(ExtendedStats::getVariancePopulation, aggField)); + case STDDEV_SAMP -> Pair.of( + AggregationBuilders.extendedStats(aggField).field(argStr), + new StatsParser(ExtendedStats::getStdDeviationSampling, aggField)); + case STDDEV_POP -> Pair.of( + AggregationBuilders.extendedStats(aggField).field(argStr), + new StatsParser(ExtendedStats::getStdDeviationPopulation, aggField)); + // TODO: below UDAF should support push down once implemented + // https://github.com/opensearch-project/sql/issues/3385 + // case take + // case percentile + // case percentile_approx + default -> throw new AggregateAnalyzerException( + String.format("unsupported aggregator %s", aggCall.getAggregation())); + }; + } + + private static List> createCompositeBuckets( + List groupList, FieldExpressionCreator fieldExpressionCreator) { + + ImmutableList.Builder> resultBuilder = ImmutableList.builder(); + + for (int groupIndex : groupList) { + NamedFieldExpression groupExpr = fieldExpressionCreator.create(groupIndex); + + // TODO: support histogram bucket(i.e. 
PPL span expression) + // https://github.com/opensearch-project/sql/issues/3384 + CompositeValuesSourceBuilder sourceBuilder = createTermsSourceBuilder(groupExpr); + + resultBuilder.add(sourceBuilder); + } + + return resultBuilder.build(); + } + + private static CompositeValuesSourceBuilder createTermsSourceBuilder( + NamedFieldExpression groupExpr) { + + CompositeValuesSourceBuilder sourceBuilder = + new TermsValuesSourceBuilder(groupExpr.getRootName()) + .missingBucket(true) + // TODO: use Sort's option if there is Sort push-down into aggregation + // https://github.com/opensearch-project/sql/issues/3380 + .missingOrder(MissingOrder.FIRST) + .order(SortOrder.ASC) + .field(groupExpr.getReferenceForTermQuery()); + + // Time types values are converted to LONG in ExpressionAggregationScript::execute + if (List.of(TIMESTAMP, TIME, DATE) + .contains(groupExpr.getOpenSearchDataType().getExprCoreType())) { + sourceBuilder.userValuetypeHint(ValueType.LONG); + } + + return sourceBuilder; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java index 73c1ffe2410..21273c743e7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java @@ -141,7 +141,7 @@ public static QueryBuilder analyze( } } - /** Traverses {@link RexNode} tree and builds ES query. */ + /** Traverses {@link RexNode} tree and builds OpenSearch query. */ private static class Visitor extends RexVisitorImpl { List schema; @@ -975,6 +975,12 @@ static final class NamedFieldExpression implements TerminalExpression { private final String name; private final OpenSearchDataType type; + NamedFieldExpression( + int refIndex, List schema, Map typeMapping) { + this.name = refIndex >= schema.size() ? 
null : schema.get(refIndex); + this.type = typeMapping.get(name); + } + private NamedFieldExpression() { this.name = null; this.type = null; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index c77bb3c94d7..f4903e87a41 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -33,7 +33,7 @@ import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; -import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalTableScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexEnumerator; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; @@ -93,8 +93,7 @@ public OpenSearchIndex(OpenSearchClient client, Settings settings, String indexN @Override public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable relOptTable) { final RelOptCluster cluster = context.getCluster(); - // return new CalciteOpenSearchIndexScan(cluster, relOptTable, this); - return new CalciteLogicalTableScan(cluster, relOptTable, this); + return new CalciteLogicalIndexScan(cluster, relOptTable, this); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java new file mode 100644 index 00000000000..b748ee23b37 --- /dev/null +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import java.util.List; +import org.apache.calcite.adapter.enumerable.EnumerableConvention; +import org.apache.calcite.adapter.enumerable.EnumerableRel; +import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor; +import org.apache.calcite.adapter.enumerable.PhysType; +import org.apache.calcite.adapter.enumerable.PhysTypeImpl; +import org.apache.calcite.linq4j.AbstractEnumerable; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.linq4j.tree.Blocks; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.hint.RelHint; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.opensearch.sql.calcite.plan.OpenSearchRules; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; + +/** Relational expression representing a scan of an OpenSearchIndex type. */ +public class CalciteEnumerableIndexScan extends CalciteIndexScan implements EnumerableRel { + private static final Logger LOG = LogManager.getLogger(CalciteEnumerableIndexScan.class); + + /** + * Creates a CalciteEnumerableIndexScan. 
+ * + * @param cluster Cluster + * @param table Table + * @param osIndex OpenSearch index + */ + public CalciteEnumerableIndexScan( + RelOptCluster cluster, + List hints, + RelOptTable table, + OpenSearchIndex osIndex, + RelDataType schema, + PushDownContext pushDownContext) { + super( + cluster, + cluster.traitSetOf(EnumerableConvention.INSTANCE), + hints, + table, + osIndex, + schema, + pushDownContext); + } + + @Override + public void register(RelOptPlanner planner) { + for (RelOptRule rule : OpenSearchRules.OPEN_SEARCH_OPT_RULES) { + planner.addRule(rule); + } + + // remove this rule otherwise opensearch can't correctly interpret approx_count_distinct() + // it is converted to cardinality aggregation in OpenSearch + planner.removeRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES); + } + + @Override + public Result implement(EnumerableRelImplementor implementor, Prefer pref) { + /* In Calcite enumerable operators, row of single column will be optimized to a scalar value. + * See {@link PhysTypeImpl}. + * Since we need to combine this operator with their original ones, + * let's follow this convention to apply the optimization here and ensure `scan` method + * returns the correct data format for single column rows. 
+ * See {@link OpenSearchIndexEnumerator} + */ + PhysType physType = + PhysTypeImpl.of(implementor.getTypeFactory(), getRowType(), pref.preferArray()); + + Expression scanOperator = implementor.stash(this, CalciteEnumerableIndexScan.class); + return implementor.result(physType, Blocks.toBlock(Expressions.call(scanOperator, "scan"))); + } + + public Enumerable<@Nullable Object> scan() { + OpenSearchRequestBuilder requestBuilder = osIndex.createRequestBuilder(); + pushDownContext.forEach(action -> action.apply(requestBuilder)); + return new AbstractEnumerable<>() { + @Override + public Enumerator enumerator() { + return new OpenSearchIndexEnumerator( + osIndex.getClient(), + List.copyOf(getRowType().getFieldNames()), + requestBuilder.getMaxResponseSize(), + osIndex.buildRequest(requestBuilder)); + } + }; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScan.java new file mode 100644 index 00000000000..0232ddf406c --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScan.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static java.util.Objects.requireNonNull; + +import java.util.ArrayDeque; +import java.util.List; +import lombok.Getter; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.hint.RelHint; +import org.apache.calcite.rel.type.RelDataType; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; + +/** Relational expression representing a scan of an OpenSearchIndex type. 
*/ +@Getter +public abstract class CalciteIndexScan extends TableScan { + protected final OpenSearchIndex osIndex; + // The schema of this scan operator, it's initialized with the row type of the table, but may be + // changed by push down operations. + protected final RelDataType schema; + // This context maintains all the push down actions, which will be applied to the requestBuilder + // when it begins to scan data from OpenSearch. + // Because OpenSearchRequestBuilder doesn't support deep copy while we want to keep the + // requestBuilder independent among different plans produced in the optimization process, + // so we cannot apply these actions right away. + protected final PushDownContext pushDownContext; + + protected CalciteIndexScan( + RelOptCluster cluster, + RelTraitSet traitSet, + List hints, + RelOptTable table, + OpenSearchIndex osIndex, + RelDataType schema, + PushDownContext pushDownContext) { + super(cluster, traitSet, hints, table); + this.osIndex = requireNonNull(osIndex, "OpenSearch index"); + this.schema = schema; + this.pushDownContext = pushDownContext; + } + + @Override + public RelDataType deriveRowType() { + return this.schema; + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw) + .itemIf("PushDownContext", pushDownContext, !pushDownContext.isEmpty()); + } + + // TODO: should PushDownContexts that differ only in push-down order be considered equivalent? + public static class PushDownContext extends ArrayDeque { + private boolean isAggregatePushed = false; + + @Override + public PushDownContext clone() { + return (PushDownContext) super.clone(); + } + + @Override + public boolean add(PushDownAction pushDownAction) { + // Defense check. It should never do push down to this context after aggregate push-down. 
+ assert !isAggregatePushed : "Aggregate has already been pushed!"; + if (pushDownAction.type == PushDownType.AGGREGATION) { + isAggregatePushed = true; + } + return super.add(pushDownAction); + } + + public boolean isAggregatePushed() { + if (isAggregatePushed) return true; + isAggregatePushed = !isEmpty() && super.peekLast().type == PushDownType.AGGREGATION; + return isAggregatePushed; + } + } + + protected enum PushDownType { + FILTER, + PROJECT, + AGGREGATION, + // SORT, + // LIMIT, + // HIGHLIGHT, + // NESTED + } + + public record PushDownAction(PushDownType type, Object digest, AbstractAction action) { + static PushDownAction of(PushDownType type, Object digest, AbstractAction action) { + return new PushDownAction(type, digest, action); + } + + public String toString() { + return type + ":" + digest; + } + + public void apply(OpenSearchRequestBuilder requestBuilder) { + action.apply(requestBuilder); + } + } + + public interface AbstractAction { + void apply(OpenSearchRequestBuilder requestBuilder); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java new file mode 100644 index 00000000000..84f9d85f70f --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelTraitSet; 
+import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.hint.RelHint; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.planner.physical.EnumerableIndexScanRule; +import org.opensearch.sql.opensearch.planner.physical.OpenSearchIndexRules; +import org.opensearch.sql.opensearch.request.AggregateAnalyzer; +import org.opensearch.sql.opensearch.request.PredicateAnalyzer; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; + +@Getter +public class CalciteLogicalIndexScan extends CalciteIndexScan { + private static final Logger LOG = LogManager.getLogger(CalciteLogicalIndexScan.class); + + public CalciteLogicalIndexScan( + RelOptCluster cluster, RelOptTable table, OpenSearchIndex osIndex) { + this( + cluster, + cluster.traitSetOf(Convention.NONE), + ImmutableList.of(), + table, + osIndex, + table.getRowType(), + new PushDownContext()); + } + + protected CalciteLogicalIndexScan( + RelOptCluster cluster, + RelTraitSet traitSet, + List hints, + RelOptTable table, + OpenSearchIndex osIndex, + RelDataType schema, + PushDownContext pushDownContext) { + super(cluster, traitSet, hints, table, osIndex, schema, pushDownContext); + } + + public CalciteLogicalIndexScan copyWithNewSchema(RelDataType schema) { + // Shallow-copy the pushDownContext, thus
requestBuilder among different plans produced in the + // optimization process won't affect each other. + return new CalciteLogicalIndexScan( + getCluster(), traitSet, hints, table, osIndex, schema, pushDownContext.clone()); + } + + @Override + public void register(RelOptPlanner planner) { + super.register(planner); + planner.addRule(EnumerableIndexScanRule.DEFAULT_CONFIG.toRule()); + if (osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_PUSHDOWN_ENABLED)) { + for (RelOptRule rule : OpenSearchIndexRules.OPEN_SEARCH_INDEX_SCAN_RULES) { + planner.addRule(rule); + } + } + } + + public CalciteLogicalIndexScan pushDownFilter(Filter filter) { + try { + CalciteLogicalIndexScan newScan = this.copyWithNewSchema(filter.getRowType()); + List schema = this.getRowType().getFieldNames(); + Map typeMapping = this.osIndex.getFieldOpenSearchTypes(); + QueryBuilder filterBuilder = + PredicateAnalyzer.analyze(filter.getCondition(), schema, typeMapping); + newScan.pushDownContext.add( + PushDownAction.of( + PushDownType.FILTER, + filter.getCondition(), + requestBuilder -> requestBuilder.pushDownFilter(filterBuilder))); + + // TODO: handle the case where condition contains a score function + return newScan; + } catch (Exception e) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the filter condition {}", filter.getCondition(), e); + } else { + LOG.warn("Cannot pushdown the filter condition {}, ", filter.getCondition()); + } + } + return null; + } + + /** + * When pushing down a project, we need to create a new CalciteLogicalIndexScan with the updated + * schema since we cannot override getRowType() which is defined to be final. 
+ */ + public CalciteLogicalIndexScan pushDownProject(List selectedColumns) { + final RelDataTypeFactory.Builder builder = getCluster().getTypeFactory().builder(); + final List fieldList = this.getRowType().getFieldList(); + for (int project : selectedColumns) { + builder.add(fieldList.get(project)); + } + RelDataType newSchema = builder.build(); + CalciteLogicalIndexScan newScan = this.copyWithNewSchema(newSchema); + newScan.pushDownContext.add( + PushDownAction.of( + PushDownType.PROJECT, + newSchema.getFieldNames(), + requestBuilder -> + requestBuilder.pushDownProjectStream(newSchema.getFieldNames().stream()))); + return newScan; + } + + public CalciteLogicalIndexScan pushDownAggregate(Aggregate aggregate) { + try { + CalciteLogicalIndexScan newScan = this.copyWithNewSchema(aggregate.getRowType()); + List schema = this.getRowType().getFieldNames(); + Map typeMapping = this.osIndex.getFieldOpenSearchTypes(); + List outputFields = aggregate.getRowType().getFieldNames(); + final Pair, OpenSearchAggregationResponseParser> aggregationBuilder = + AggregateAnalyzer.analyze(aggregate, schema, typeMapping, outputFields); + Map extendedTypeMapping = + aggregate.getRowType().getFieldList().stream() + .collect( + Collectors.toMap( + RelDataTypeField::getName, + field -> + OpenSearchDataType.of( + OpenSearchTypeFactory.convertRelDataTypeToExprType( + field.getType())))); + newScan.pushDownContext.add( + PushDownAction.of( + PushDownType.AGGREGATION, + aggregate, + requestBuilder -> { + requestBuilder.pushDownAggregation(aggregationBuilder); + requestBuilder.pushTypeMapping(extendedTypeMapping); + })); + return newScan; + } catch (Exception e) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the aggregate {}", aggregate, e); + } else { + LOG.warn("Cannot pushdown the aggregate {}, ", aggregate); + } + } + return null; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalTableScan.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalTableScan.java deleted file mode 100644 index 30b2193d927..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalTableScan.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import com.google.common.collect.ImmutableList; -import java.util.List; -import lombok.Getter; -import org.apache.calcite.plan.Convention; -import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.plan.RelOptPlanner; -import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.plan.RelTraitSet; -import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.hint.RelHint; -import org.opensearch.sql.opensearch.planner.physical.EnumerableIndexScanRule; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; - -@Getter -public class CalciteLogicalTableScan extends TableScan { - private final OpenSearchIndex osIndex; - - protected CalciteLogicalTableScan( - RelOptCluster cluster, - RelTraitSet traitSet, - List hints, - RelOptTable table, - OpenSearchIndex osIndex) { - super(cluster, traitSet, hints, table); - this.osIndex = osIndex; - } - - public CalciteLogicalTableScan( - RelOptCluster cluster, RelOptTable table, OpenSearchIndex osIndex) { - this(cluster, cluster.traitSetOf(Convention.NONE), ImmutableList.of(), table, osIndex); - } - - @Override - public void register(RelOptPlanner planner) { - super.register(planner); - planner.addRule(EnumerableIndexScanRule.DEFAULT_CONFIG.toRule()); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java deleted file mode 100644 index 1c84447ab09..00000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteOpenSearchIndexScan.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import static java.util.Objects.requireNonNull; - -import java.util.ArrayDeque; -import java.util.List; -import java.util.Map; -import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor; -import org.apache.calcite.adapter.enumerable.PhysType; -import org.apache.calcite.adapter.enumerable.PhysTypeImpl; -import org.apache.calcite.linq4j.AbstractEnumerable; -import org.apache.calcite.linq4j.Enumerable; -import org.apache.calcite.linq4j.Enumerator; -import org.apache.calcite.linq4j.tree.Blocks; -import org.apache.calcite.linq4j.tree.Expression; -import org.apache.calcite.linq4j.tree.Expressions; -import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.plan.RelOptPlanner; -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.plan.RelTraitSet; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.RelWriter; -import org.apache.calcite.rel.core.Filter; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.sql.calcite.plan.OpenSearchTableScan; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.planner.physical.OpenSearchIndexRules; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; -import org.opensearch.sql.opensearch.request.PredicateAnalyzer; 
-import org.opensearch.sql.opensearch.storage.OpenSearchIndex; - -/** Relational expression representing a scan of an OpenSearchIndex type. */ -public class CalciteOpenSearchIndexScan extends OpenSearchTableScan { - private static final Logger LOG = LogManager.getLogger(CalciteOpenSearchIndexScan.class); - - private final OpenSearchIndex osIndex; - // The schema of this scan operator, it's initialized with the row type of the table, but may be - // changed by push down operations. - private final RelDataType schema; - // This context maintains all the push down actions, which will be applied to the requestBuilder - // when it begins to scan data from OpenSearch. - // Because OpenSearchRequestBuilder doesn't support deep copy while we want to keep the - // requestBuilder independent among different plans produced in the optimization process, - // so we cannot apply these actions right away. - private final PushDownContext pushDownContext; - - /** - * Creates an CalciteOpenSearchIndexScan. - * - * @param cluster Cluster - * @param table Table - * @param index OpenSearch index - */ - public CalciteOpenSearchIndexScan( - RelOptCluster cluster, RelOptTable table, OpenSearchIndex index) { - this(cluster, table, index, table.getRowType(), new PushDownContext()); - } - - private CalciteOpenSearchIndexScan( - RelOptCluster cluster, - RelOptTable table, - OpenSearchIndex index, - RelDataType schema, - PushDownContext pushDownContext) { - super(cluster, table); - this.osIndex = requireNonNull(index, "OpenSearch index"); - this.schema = schema; - this.pushDownContext = pushDownContext; - } - - public CalciteOpenSearchIndexScan copy() { - return new CalciteOpenSearchIndexScan( - getCluster(), table, osIndex, this.schema, pushDownContext.clone()); - } - - public CalciteOpenSearchIndexScan copyWithNewSchema(RelDataType schema) { - // Do shallow copy for requestBuilder, thus requestBuilder among different plans produced in the - // optimization process won't affect each other. 
- return new CalciteOpenSearchIndexScan( - getCluster(), table, osIndex, schema, pushDownContext.clone()); - } - - @Override - public RelNode copy(RelTraitSet traitSet, List inputs) { - assert inputs.isEmpty(); - return new CalciteOpenSearchIndexScan(getCluster(), table, osIndex); - } - - @Override - public RelWriter explainTerms(RelWriter pw) { - return super.explainTerms(pw) - .itemIf("PushDownContext", pushDownContext, !pushDownContext.isEmpty()); - } - - @Override - public void register(RelOptPlanner planner) { - super.register(planner); - if (osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_PUSHDOWN_ENABLED)) { - for (RelOptRule rule : OpenSearchIndexRules.OPEN_SEARCH_INDEX_SCAN_RULES) { - planner.addRule(rule); - } - } - } - - @Override - public RelDataType deriveRowType() { - return this.schema; - } - - @Override - public Result implement(EnumerableRelImplementor implementor, Prefer pref) { - /* In Calcite enumerable operators, row of single column will be optimized to a scalar value. - * See {@link PhysTypeImpl}. - * Since we need to combine this operator with their original ones, - * let's follow this convention to apply the optimization here and ensure `scan` method - * returns the correct data format for single column rows. 
- * See {@link OpenSearchIndexEnumerator} - */ - PhysType physType = - PhysTypeImpl.of(implementor.getTypeFactory(), getRowType(), pref.preferArray()); - - Expression scanOperator = implementor.stash(this, CalciteOpenSearchIndexScan.class); - return implementor.result(physType, Blocks.toBlock(Expressions.call(scanOperator, "scan"))); - } - - public Enumerable<@Nullable Object> scan() { - OpenSearchRequestBuilder requestBuilder = osIndex.createRequestBuilder(); - pushDownContext.forEach(action -> action.apply(requestBuilder)); - return new AbstractEnumerable<>() { - @Override - public Enumerator enumerator() { - return new OpenSearchIndexEnumerator( - osIndex.getClient(), - List.copyOf(getRowType().getFieldNames()), - requestBuilder.getMaxResponseSize(), - osIndex.buildRequest(requestBuilder)); - } - }; - } - - public CalciteOpenSearchIndexScan pushDownFilter(Filter filter) { - try { - CalciteOpenSearchIndexScan newScan = this.copyWithNewSchema(filter.getRowType()); - List schema = this.getRowType().getFieldNames(); - Map typeMapping = this.osIndex.getFieldOpenSearchTypes(); - QueryBuilder filterBuilder = - PredicateAnalyzer.analyze(filter.getCondition(), schema, typeMapping); - newScan.pushDownContext.add( - PushDownAction.of( - PushDownType.FILTER, - filter.getCondition(), - requestBuilder -> requestBuilder.pushDownFilter(filterBuilder))); - - // TODO: handle the case where condition contains a score function - return newScan; - } catch (Exception e) { - LOG.warn("Cannot pushdown the filter condition {}, ", filter.getCondition()); - } - return null; - } - - /** - * When pushing down a project, we need to create a new CalciteOpenSearchIndexScan with the - * updated schema since we cannot override getRowType() which is defined to be final. 
- */ - public CalciteOpenSearchIndexScan pushDownProject(List selectedColumns) { - final RelDataTypeFactory.Builder builder = getCluster().getTypeFactory().builder(); - final List fieldList = this.getRowType().getFieldList(); - for (int project : selectedColumns) { - builder.add(fieldList.get(project)); - } - RelDataType newSchema = builder.build(); - CalciteOpenSearchIndexScan newScan = this.copyWithNewSchema(newSchema); - newScan.pushDownContext.add( - PushDownAction.of( - PushDownType.PROJECT, - newSchema.getFieldNames(), - requestBuilder -> - requestBuilder.pushDownProjectStream(newSchema.getFieldNames().stream()))); - return newScan; - } - - // TODO: should we consider equivalent among PushDownContexts with different push down sequence? - static class PushDownContext extends ArrayDeque { - @Override - public PushDownContext clone() { - return (PushDownContext) super.clone(); - } - } - - private enum PushDownType { - FILTER, - PROJECT, - // AGGREGATION, - // SORT, - // LIMIT, - // HIGHLIGHT, - // NESTED - } - - private record PushDownAction(PushDownType type, Object digest, AbstractAction action) { - static PushDownAction of(PushDownType type, Object digest, AbstractAction action) { - return new PushDownAction(type, digest, action); - } - - public String toString() { - return type + ":" + digest; - } - - void apply(OpenSearchRequestBuilder requestBuilder) { - action.apply(requestBuilder); - } - } - - private interface AbstractAction { - void apply(OpenSearchRequestBuilder requestBuilder); - } -}