diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index c3c1f43f8e142..974a49d9d9f34 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -186,6 +186,9 @@ public class Analysis private Optional callTarget = Optional.empty(); private Optional targetQuery = Optional.empty(); + // for create vector index + private Optional createVectorIndexAnalysis = Optional.empty(); + // for create table private Optional createTableDestination = Optional.empty(); private Map createTableProperties = ImmutableMap.of(); @@ -700,6 +703,16 @@ public Optional getCreateTableDestination() return createTableDestination; } + public void setCreateVectorIndexAnalysis(CreateVectorIndexAnalysis analysis) + { + this.createVectorIndexAnalysis = Optional.of(analysis); + } + + public Optional getCreateVectorIndexAnalysis() + { + return createVectorIndexAnalysis; + } + public Optional getProcedureName() { return procedureName; @@ -1937,4 +1950,53 @@ public Scope getTargetTableScope() return targetTableScope; } } + + @Immutable + public static final class CreateVectorIndexAnalysis + { + private final QualifiedObjectName sourceTableName; + private final QualifiedObjectName targetTableName; + private final List columns; + private final Map properties; + private final Optional updatingFor; + + public CreateVectorIndexAnalysis( + QualifiedObjectName sourceTableName, + QualifiedObjectName targetTableName, + List columns, + Map properties, + Optional updatingFor) + { + this.sourceTableName = requireNonNull(sourceTableName, "sourceTableName is null"); + this.targetTableName = requireNonNull(targetTableName, "targetTableName is null"); + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + this.properties = 
ImmutableMap.copyOf(requireNonNull(properties, "properties is null")); + this.updatingFor = requireNonNull(updatingFor, "updatingFor is null"); + } + + public QualifiedObjectName getSourceTableName() + { + return sourceTableName; + } + + public QualifiedObjectName getTargetTableName() + { + return targetTableName; + } + + public List getColumns() + { + return columns; + } + + public Map getProperties() + { + return properties; + } + + public Optional getUpdatingFor() + { + return updatingFor; + } + } } diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java index 1be20437e8a1c..dfeb0b7899e37 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateTag; import com.facebook.presto.sql.tree.CreateType; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Deallocate; import com.facebook.presto.sql.tree.Delete; @@ -107,6 +108,7 @@ private StatementUtils() {} builder.put(CreateTableAsSelect.class, QueryType.INSERT); builder.put(Insert.class, QueryType.INSERT); builder.put(RefreshMaterializedView.class, QueryType.INSERT); + builder.put(CreateVectorIndex.class, QueryType.INSERT); builder.put(Delete.class, QueryType.DELETE); builder.put(Update.class, QueryType.UPDATE); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java index cff7bec518dde..8b57720212529 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java @@ -22,6 +22,7 @@ import com.facebook.presto.metadata.AnalyzeTableHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.MergeHandle; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; @@ -36,6 +37,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -122,6 +124,11 @@ private static Optional createWriterTarget(Optional new VerifyException("mergeHandle is absent: " + target.getClass().getSimpleName())))); } + if (target instanceof TableWriterNode.CreateVectorIndexReference) { + throw new PrestoException(NOT_SUPPORTED, + "This connector does not support creating vector indexes. 
" + + "The connector must provide a ConnectorPlanOptimizer to handle CREATE VECTOR INDEX."); + } throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 4bca843875cef..da001aecc12bb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -62,6 +62,7 @@ import com.facebook.presto.operator.aggregation.ClassificationThresholdsAggregation; import com.facebook.presto.operator.aggregation.CountAggregation; import com.facebook.presto.operator.aggregation.CountIfAggregation; +import com.facebook.presto.operator.aggregation.CreateVectorIndexAggregation; import com.facebook.presto.operator.aggregation.DefaultApproximateCountDistinctAggregation; import com.facebook.presto.operator.aggregation.DoubleCorrelationAggregation; import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation; @@ -710,6 +711,7 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregate(GeometryUnionAgg.class) .aggregate(SpatialPartitioningAggregateFunction.class) .aggregate(SpatialPartitioningInternalAggregateFunction.class) + .aggregates(CreateVectorIndexAggregation.class) .aggregates(CountAggregation.class) .aggregates(VarianceAggregation.class) .aggregates(CentralMomentsAggregation.class) diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java index 6165940a70f92..d3705facea051 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; @@ -310,6 +311,22 @@ public Optional finishCreateTable( return delegate.finishCreateTable(session, tableHandle, fragments, computedStatistics); } + @Override + public OutputTableHandle beginCreateVectorIndex(Session session, String catalogName, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName) + { + return delegate.beginCreateVectorIndex(session, catalogName, indexMetadata, layout, sourceTableName); + } + + @Override + public Optional finishCreateVectorIndex( + Session session, + OutputTableHandle tableHandle, + Collection fragments, + Collection computedStatistics) + { + return delegate.finishCreateVectorIndex(session, tableHandle, fragments, computedStatistics); + } + @Override public Optional getInsertLayout(Session session, TableHandle target) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index edb2285ba23d5..25c736829fc5f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableLayoutFilterCoverage; @@ -69,6 +70,7 @@ import 
java.util.OptionalLong; import java.util.Set; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.TableLayoutFilterCoverage.NOT_APPLICABLE; public interface Metadata @@ -270,6 +272,22 @@ public interface Metadata */ Optional finishCreateTable(Session session, OutputTableHandle tableHandle, Collection fragments, Collection computedStatistics); + /** + * Begin the atomic creation of a vector index with data. + */ + default OutputTableHandle beginCreateVectorIndex(Session session, String catalogName, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support creating vector indexes"); + } + + /** + * Finish a vector index creation with data after the data is written. + */ + default Optional finishCreateVectorIndex(Session session, OutputTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support creating vector indexes"); + } + Optional getInsertLayout(Session session, TableHandle target); /** diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java index 183cbe91f252b..3628f915b47e3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -957,6 +957,27 @@ public Optional finishCreateTable(Session session, Outp return metadata.finishCreateTable(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), fragments, computedStatistics); } + @Override + public OutputTableHandle beginCreateVectorIndex(Session session, String catalogName, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName) + { + CatalogMetadata 
catalogMetadata = getCatalogMetadataForWrite(session, catalogName); + ConnectorId connectorId = catalogMetadata.getConnectorId(); + ConnectorMetadata metadata = catalogMetadata.getMetadata(); + + ConnectorTransactionHandle transactionHandle = catalogMetadata.getTransactionHandleFor(connectorId); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + ConnectorOutputTableHandle handle = metadata.beginCreateVectorIndex(connectorSession, indexMetadata, layout.map(NewTableLayout::getLayout), sourceTableName); + return new OutputTableHandle(connectorId, transactionHandle, handle); + } + + @Override + public Optional finishCreateVectorIndex(Session session, OutputTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + return metadata.finishCreateVectorIndex(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), fragments, computedStatistics); + } + @Override public InsertTableHandle beginInsert(Session session, TableHandle tableHandle) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManagerStats.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManagerStats.java index 3f4c80c01cfda..2678398c94328 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManagerStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManagerStats.java @@ -55,6 +55,8 @@ public class MetadataManagerStats private final AtomicLong getNewTableLayoutCalls = new AtomicLong(); private final AtomicLong beginCreateTableCalls = new AtomicLong(); private final AtomicLong finishCreateTableCalls = new AtomicLong(); + private final AtomicLong beginCreateVectorIndexCalls = new AtomicLong(); + private final AtomicLong finishCreateVectorIndexCalls = new AtomicLong(); private final AtomicLong 
getInsertLayoutCalls = new AtomicLong(); private final AtomicLong getStatisticsCollectionMetadataForWriteCalls = new AtomicLong(); private final AtomicLong getStatisticsCollectionMetadataCalls = new AtomicLong(); @@ -165,6 +167,8 @@ public class MetadataManagerStats private final TimeStat getNewTableLayoutTime = new TimeStat(TimeUnit.NANOSECONDS); private final TimeStat beginCreateTableTime = new TimeStat(TimeUnit.NANOSECONDS); private final TimeStat finishCreateTableTime = new TimeStat(TimeUnit.NANOSECONDS); + private final TimeStat beginCreateVectorIndexTime = new TimeStat(TimeUnit.NANOSECONDS); + private final TimeStat finishCreateVectorIndexTime = new TimeStat(TimeUnit.NANOSECONDS); private final TimeStat getInsertLayoutTime = new TimeStat(TimeUnit.NANOSECONDS); private final TimeStat getStatisticsCollectionMetadataForWriteTime = new TimeStat(TimeUnit.NANOSECONDS); private final TimeStat getStatisticsCollectionMetadataTime = new TimeStat(TimeUnit.NANOSECONDS); @@ -648,6 +652,20 @@ public TimeStat getFinishCreateTableTime() return finishCreateTableTime; } + @Managed + @Nested + public TimeStat getBeginCreateVectorIndexTime() + { + return beginCreateVectorIndexTime; + } + + @Managed + @Nested + public TimeStat getFinishCreateVectorIndexTime() + { + return finishCreateVectorIndexTime; + } + @Managed @Nested public TimeStat getGetInsertLayoutTime() @@ -1357,6 +1375,18 @@ public void recordFinishCreateTableCall(long duration) finishCreateTableTime.add(duration, TimeUnit.NANOSECONDS); } + public void recordBeginCreateVectorIndexCall(long duration) + { + beginCreateVectorIndexCalls.incrementAndGet(); + beginCreateVectorIndexTime.add(duration, TimeUnit.NANOSECONDS); + } + + public void recordFinishCreateVectorIndexCall(long duration) + { + finishCreateVectorIndexCalls.incrementAndGet(); + finishCreateVectorIndexTime.add(duration, TimeUnit.NANOSECONDS); + } + public void recordGetInsertLayoutCall(long duration) { getInsertLayoutCalls.incrementAndGet(); diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/StatsRecordingMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StatsRecordingMetadataManager.java index 1fa6ea8ffd3a1..b33ed39dac07f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/StatsRecordingMetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StatsRecordingMetadataManager.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableLayoutFilterCoverage; @@ -195,6 +196,30 @@ public Optional finishCreateTable(Session session, Outp } } + @Override + public OutputTableHandle beginCreateVectorIndex(Session session, String catalogName, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName) + { + long startTime = System.nanoTime(); + try { + return delegate.beginCreateVectorIndex(session, catalogName, indexMetadata, layout, sourceTableName); + } + finally { + stats.recordBeginCreateVectorIndexCall(System.nanoTime() - startTime); + } + } + + @Override + public Optional finishCreateVectorIndex(Session session, OutputTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + long startTime = System.nanoTime(); + try { + return delegate.finishCreateVectorIndex(session, tableHandle, fragments, computedStatistics); + } + finally { + stats.recordFinishCreateVectorIndexCall(System.nanoTime() - startTime); + } + } + @Override public Optional getInsertLayout(Session session, TableHandle target) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/CreateVectorIndexAggregation.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/CreateVectorIndexAggregation.java new file mode 100644 index 0000000000000..9231cd1bc38e8 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/CreateVectorIndexAggregation.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.SliceState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; + +/** + * Dummy aggregate function for CREATE VECTOR INDEX planning. + * This function is never executed — the connector optimizer replaces + * the plan tree before execution. 
+ */ +@AggregationFunction("create_vector_index") +public final class CreateVectorIndexAggregation +{ + private CreateVectorIndexAggregation() {} + + // 1-arg overloads: embedding only (no id) + + @InputFunction + public static void inputRealArray( + @AggregationState SliceState state, + @SqlType("array(real)") Block embedding) + { + } + + @InputFunction + public static void inputDoubleArray( + @AggregationState SliceState state, + @SqlType("array(double)") Block embedding) + { + } + + // 2-arg overloads: id + embedding (matches SQL syntax: ON table(id, embedding)) + + @InputFunction + @TypeParameter("T") + public static void inputRealArrayWithLongId( + @AggregationState SliceState state, + @SqlType("T") long id, + @SqlType("array(real)") Block embedding) + { + } + + @InputFunction + @TypeParameter("T") + public static void inputRealArrayWithDoubleId( + @AggregationState SliceState state, + @SqlType("T") double id, + @SqlType("array(real)") Block embedding) + { + } + + @InputFunction + @TypeParameter("T") + public static void inputRealArrayWithSliceId( + @AggregationState SliceState state, + @SqlType("T") Slice id, + @SqlType("array(real)") Block embedding) + { + } + + @InputFunction + @TypeParameter("T") + public static void inputDoubleArrayWithLongId( + @AggregationState SliceState state, + @SqlType("T") long id, + @SqlType("array(double)") Block embedding) + { + } + + @InputFunction + @TypeParameter("T") + public static void inputDoubleArrayWithDoubleId( + @AggregationState SliceState state, + @SqlType("T") double id, + @SqlType("array(double)") Block embedding) + { + } + + @InputFunction + @TypeParameter("T") + public static void inputDoubleArrayWithSliceId( + @AggregationState SliceState state, + @SqlType("T") Slice id, + @SqlType("array(double)") Block embedding) + { + } + + @CombineFunction + public static void combine( + @AggregationState SliceState state, + @AggregationState SliceState otherState) + { + } + + @OutputFunction(StandardTypes.VARCHAR) + public 
static void output(@AggregationState SliceState state, BlockBuilder out) + { + VARCHAR.writeSlice(out, Slices.utf8Slice("")); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 8b1f7c0809fba..d22a1d64310a1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -110,6 +110,7 @@ import com.facebook.presto.sql.tree.CreateSchema; import com.facebook.presto.sql.tree.CreateTable; import com.facebook.presto.sql.tree.CreateTableAsSelect; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Cube; import com.facebook.presto.sql.tree.Deallocate; @@ -1144,6 +1145,71 @@ protected Scope visitCreateTable(CreateTable node, Optional scope) return createAndAssignScope(node, scope); } + @Override + protected Scope visitCreateVectorIndex(CreateVectorIndex node, Optional scope) + { + QualifiedObjectName sourceTableName = createQualifiedObjectName(session, node, node.getTableName(), metadata); + if (!metadataResolver.tableExists(sourceTableName)) { + throw new SemanticException(MISSING_TABLE, node, "Source table '%s' does not exist", sourceTableName); + } + + QualifiedObjectName targetTable = createQualifiedObjectName(session, node, node.getIndexName(), metadata); + if (metadataResolver.tableExists(targetTable)) { + throw new SemanticException(TABLE_ALREADY_EXISTS, node, "Destination table '%s' already exists", targetTable); + } + + // Analyze the source table to build a proper scope with typed columns + // Use AllowAllAccessControl since we check permissions separately below + StatementAnalyzer analyzer = new StatementAnalyzer( + analysis, + metadata, + sqlParser, + new AllowAllAccessControl(), + session, + 
warningCollector); + + Table sourceTable = new Table(node.getTableName()); + Scope tableScope = analyzer.analyze(sourceTable, scope); + + // Validate that specified columns exist in the source table + TableHandle sourceTableHandle = metadataResolver.getTableHandle(sourceTableName).get(); + Map sourceColumns = metadataResolver.getColumnHandles(sourceTableHandle); + for (Identifier column : node.getColumns()) { + if (!sourceColumns.containsKey(column.getValue())) { + throw new SemanticException(MISSING_COLUMN, column, "Column '%s' does not exist in source table '%s'", column.getValue(), sourceTableName); + } + } + + // Analyze UPDATING FOR predicate (validates column references, types, etc.) + node.getUpdatingFor().ifPresent(where -> analyzeWhere(node, tableScope, where)); + + validateProperties(node.getProperties(), scope); + + Map allProperties = mapFromProperties(node.getProperties()); + + // user must have read permission on the source table to create a vector index + Multimap tableColumnMap = ImmutableMultimap.builder() + .putAll(sourceTableName, sourceColumns.keySet().stream() + .map(column -> new Subfield(column, ImmutableList.of())) + .collect(toImmutableSet())) + .build(); + analysis.addTableColumnAndSubfieldReferences(accessControl, session.getIdentity(), + session.getTransactionId(), session.getAccessControlContext(), tableColumnMap, tableColumnMap); + + analysis.addAccessControlCheckForTable(TABLE_CREATE, + new AccessControlInfoForTable(accessControl, session.getIdentity(), + session.getTransactionId(), session.getAccessControlContext(), targetTable)); + + analysis.setCreateVectorIndexAnalysis(new Analysis.CreateVectorIndexAnalysis( + sourceTableName, + targetTable, + node.getColumns(), + allProperties, + node.getUpdatingFor())); + + return createAndAssignScope(node, scope, Field.newUnqualified(node.getLocation(), "result", VARCHAR)); + } + @Override protected Scope visitProperty(Property node, Optional scope) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index fe73971051712..7c04d578fc832 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; @@ -29,9 +30,12 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.DeleteNode; +import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; @@ -47,10 +51,14 @@ import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TableWriterNode.DeleteHandle; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.TableStatisticsMetadata; +import com.facebook.presto.sql.ExpressionFormatter; import com.facebook.presto.sql.analyzer.Analysis; +import com.facebook.presto.sql.analyzer.ExpressionAnalyzer; +import 
com.facebook.presto.sql.analyzer.ExpressionTreeUtils; import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; @@ -66,10 +74,13 @@ import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CreateTableAsSelect; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.ExplainFormat; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Insert; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; @@ -89,6 +100,7 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -97,10 +109,13 @@ import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeUtils.writeNativeValue; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.PartitionedTableWritePolicy.MULTIPLE_WRITERS_PER_PARTITION_ALLOWED; +import static com.facebook.presto.spi.StandardErrorCode.COLUMN_NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static 
com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
@@ -113,6 +128,8 @@
 import static com.facebook.presto.sql.TemporaryTableUtil.splitIntoPartialAndFinal;
 import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference;
 import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation;
+import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
+import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
 import static com.facebook.presto.sql.planner.PlannerUtils.newVariable;
 import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression;
@@ -215,6 +232,10 @@ else if (statement instanceof RefreshMaterializedView) {
             checkState(analysis.getRefreshMaterializedViewAnalysis().isPresent(), "RefreshMaterializedView analysis is missing");
             return createRefreshMaterializedViewPlan(analysis, (RefreshMaterializedView) statement);
         }
+        else if (statement instanceof CreateVectorIndex) {
+            checkState(analysis.getCreateVectorIndexAnalysis().isPresent(), "CreateVectorIndex analysis is missing");
+            return createVectorIndexPlan(analysis, (CreateVectorIndex) statement);
+        }
         else {
             throw new PrestoException(NOT_SUPPORTED, "Unsupported statement type " + statement.getClass().getSimpleName());
         }
@@ -394,6 +415,208 @@ private RelationPlan createTableCreationPlan(Analysis analysis, Query query)
                 statisticsMetadata);
     }
 
+    /**
+     * Plans a {@code CREATE VECTOR INDEX} statement.
+     *
+     * <p>The plan scans the source table for the index columns plus every identifier found in the
+     * optional {@code UPDATING FOR} expression, filters on that expression when present, feeds the
+     * index columns to a global {@code create_vector_index} aggregation, and wraps the aggregation in
+     * a {@link TableWriterNode} / {@link TableFinishNode} pair whose target is a
+     * {@link TableWriterNode.CreateVectorIndexReference}.
+     *
+     * @param analysis analysis that must carry a {@link Analysis.CreateVectorIndexAnalysis}
+     * @param statement the statement being planned; only its source location is read here
+     * @return a relation plan rooted at the {@link TableFinishNode}, with a single VARCHAR field named {@code result}
+     * @throws PrestoException with {@code NOT_FOUND} when the source table has no handle, or with
+     *         {@code COLUMN_NOT_FOUND} when a referenced column has no handle in the source table
+     */
+    private RelationPlan createVectorIndexPlan(Analysis analysis, CreateVectorIndex statement)
+    {
+        Analysis.CreateVectorIndexAnalysis vectorIndexAnalysis = analysis.getCreateVectorIndexAnalysis().get();
+
+        QualifiedObjectName sourceTableName = vectorIndexAnalysis.getSourceTableName();
+        QualifiedObjectName targetTableName = vectorIndexAnalysis.getTargetTableName();
+
+        // Resolve source table handle and metadata
+        TableHandle sourceTableHandle = metadata.getHandleVersion(session, sourceTableName, Optional.empty())
+                .orElseThrow(() -> new PrestoException(NOT_FOUND, "Source table does not exist: " + sourceTableName));
+        TableMetadata sourceMetadata = metadata.getTableMetadata(session, sourceTableHandle);
+        Map<String, ColumnHandle> sourceColumnHandles = metadata.getColumnHandles(session, sourceTableHandle);
+
+        // Index columns (args to create_vector_index aggregate)
+        List<String> indexColumnNames = vectorIndexAnalysis.getColumns().stream()
+                .map(Identifier::getValue)
+                .collect(toImmutableList());
+
+        // Collect all columns needed: index columns + filter columns from UPDATING FOR.
+        // LinkedHashSet keeps the index columns first and drops duplicates.
+        Set<String> allColumnNames = new LinkedHashSet<>(indexColumnNames);
+        vectorIndexAnalysis.getUpdatingFor().ifPresent(expr ->
+                ExpressionTreeUtils.extractExpressions(ImmutableList.of(expr), Identifier.class)
+                        .stream()
+                        .map(Identifier::getValue)
+                        .forEach(allColumnNames::add));
+
+        // Build scan variables for all referenced columns
+        ImmutableList.Builder<VariableReferenceExpression> scanVariablesBuilder = ImmutableList.builder();
+        ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> scanAssignmentsBuilder = ImmutableMap.builder();
+        Map<String, VariableReferenceExpression> columnToVariable = new LinkedHashMap<>();
+
+        for (String columnName : allColumnNames) {
+            ColumnHandle handle = sourceColumnHandles.get(columnName);
+            if (handle == null) {
+                throw new PrestoException(COLUMN_NOT_FOUND, "Column not found: " + columnName);
+            }
+            ColumnMetadata colMeta = sourceMetadata.getColumn(columnName);
+            VariableReferenceExpression variable = variableAllocator.newVariable(
+                    getSourceLocation(statement), columnName, colMeta.getType());
+            scanVariablesBuilder.add(variable);
+            scanAssignmentsBuilder.put(variable, handle);
+            columnToVariable.put(columnName, variable);
+        }
+
+        // Build TableScanNode
+        PlanNode planNode = new TableScanNode(
+                getSourceLocation(statement),
+                idAllocator.getNextId(),
+                sourceTableHandle,
+                scanVariablesBuilder.build(),
+                scanAssignmentsBuilder.build(),
+                TupleDomain.all(),
+                TupleDomain.all(),
+                Optional.empty());
+
+        // Add FilterNode for UPDATING FOR predicate (WHERE clause equivalent)
+        if (vectorIndexAnalysis.getUpdatingFor().isPresent()) {
+            Expression updatingForExpr = vectorIndexAnalysis.getUpdatingFor().get();
+            // Rewrite Identifier references to SymbolReferences matching scan variables
+            Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
+                @Override
+                public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
+                {
+                    VariableReferenceExpression scanVariable = columnToVariable.get(node.getValue());
+                    if (scanVariable != null) {
+                        return createSymbolReference(scanVariable);
+                    }
+                    // Identifiers with no scan variable are left untouched
+                    return node;
+                }
+            }, updatingForExpr);
+
+            RowExpression filterPredicate = rowExpression(rewritten, new SqlPlannerContext(0), analysis);
+
+            planNode = new FilterNode(
+                    getSourceLocation(statement),
+                    idAllocator.getNextId(),
+                    planNode,
+                    filterPredicate);
+        }
+
+        // Build AggregationNode: create_vector_index(id, embedding) or create_vector_index(embedding)
+        List<RowExpression> aggArgs = indexColumnNames.stream()
+                .map(name -> (RowExpression) columnToVariable.get(name))
+                .collect(toImmutableList());
+
+        List<Type> argTypes = aggArgs.stream()
+                .map(RowExpression::getType)
+                .collect(toImmutableList());
+
+        FunctionHandle functionHandle = metadata.getFunctionAndTypeManager()
+                .lookupFunction("create_vector_index", fromTypes(argTypes));
+
+        VariableReferenceExpression resultVar = variableAllocator.newVariable("result", VARCHAR);
+
+        CallExpression aggCall = new CallExpression(
+                getSourceLocation(statement),
+                "create_vector_index",
+                functionHandle,
+                VARCHAR,
+                aggArgs);
+
+        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
+                aggCall,
+                Optional.empty(),
+                Optional.empty(),
+                false,
+                Optional.empty());
+
+        planNode = new AggregationNode(
+                getSourceLocation(statement),
+                idAllocator.getNextId(),
+                planNode,
+                ImmutableMap.of(resultVar, aggregation),
+                AggregationNode.globalAggregation(),
+                ImmutableList.of(),
+                AggregationNode.Step.SINGLE,
+                Optional.empty(),
+                Optional.empty(),
+                Optional.empty());
+
+        // Build target table metadata with properties
+        ConnectorId connectorId = getConnectorIdOrThrow(session, metadata, targetTableName.getCatalogName());
+
+        // Vector index properties are not registered with TablePropertyManager, so we evaluate
+        // them directly. The connector's plan optimizer is responsible for validating these properties
+        // when it rewrites the CreateVectorIndexReference plan node.
+        Map<String, Object> properties = new LinkedHashMap<>();
+        for (Map.Entry<String, Expression> entry : vectorIndexAnalysis.getProperties().entrySet()) {
+            properties.put(entry.getKey(), evaluatePropertyExpression(entry.getValue(), analysis));
+        }
+        // The UPDATING FOR expression travels to the connector as formatted SQL text under "updating_for"
+        vectorIndexAnalysis.getUpdatingFor().ifPresent(expr ->
+                properties.put("updating_for", ExpressionFormatter.formatExpression(expr, Optional.empty())));
+
+        // Target table: single VARCHAR column (aggregate output)
+        List<ColumnMetadata> targetColumns = ImmutableList.of(
+                ColumnMetadata.builder().setName("result").setType(VARCHAR).build());
+
+        ConnectorTableMetadata targetTableMetadata = new ConnectorTableMetadata(
+                toSchemaTableName(targetTableName),
+                targetColumns,
+                properties,
+                Optional.empty());
+
+        // Scope of the returned plan: one unqualified VARCHAR field named "result"
+        List<Field> fields = ImmutableList.of(
+                Field.newUnqualified(statement.getLocation(), "result", VARCHAR));
+        Scope scope = Scope.builder()
+                .withRelationType(RelationId.anonymous(), new RelationType(fields))
+                .build();
+
+        TableWriterNode.CreateVectorIndexReference writerTarget = new TableWriterNode.CreateVectorIndexReference(
+                connectorId,
+                targetTableMetadata,
+                Optional.empty(),
+                Optional.empty(),
+                toSchemaTableName(sourceTableName));
+
+        // Build plan manually with VARCHAR output (not BIGINT row count)
+        // so the connector optimizer can return a single VARCHAR result
+        TableFinishNode commitNode = new TableFinishNode(
+                planNode.getSourceLocation(),
+                idAllocator.getNextId(),
+                new TableWriterNode(
+                        planNode.getSourceLocation(),
+                        idAllocator.getNextId(),
+                        planNode,
+                        Optional.of(writerTarget),
+                        variableAllocator.newVariable("rows", BIGINT),
+                        variableAllocator.newVariable("fragments", VARBINARY),
+                        variableAllocator.newVariable("commitcontext", VARBINARY),
+                        ImmutableList.of(resultVar),
+                        ImmutableList.of("result"),
+                        ImmutableSet.of(),
+                        Optional.empty(),
+                        Optional.empty(),
+                        Optional.empty(),
+                        Optional.of(Boolean.FALSE)),
+                Optional.of(writerTarget),
+                variableAllocator.newVariable("result", VARCHAR),
+                Optional.empty(),
+                Optional.empty(),
+                Optional.empty());
+
+        return new RelationPlan(commitNode, scope, commitNode.getOutputVariables());
+    }
+
+    /**
+     * Evaluates a constant property expression and returns its value as a Java object.
+     *
+     * <p>The expression is analyzed with a constant analyzer against an empty scope to obtain its
+     * type, evaluated to its native representation, written into a one-entry block, and read back
+     * through {@link Type#getObjectValue}.
+     *
+     * @param expression the property value expression
+     * @param analysis supplies the statement parameters used during analysis and evaluation
+     * @return the object form of the evaluated value
+     */
+    private Object evaluatePropertyExpression(Expression expression, Analysis analysis)
+    {
+        ExpressionAnalyzer analyzer = ExpressionAnalyzer.createConstantAnalyzer(
+                metadata, session, analysis.getParameters(), WarningCollector.NOOP, false);
+        analyzer.analyze(expression, Scope.create());
+
+        Type type = analyzer.getExpressionTypes().get(NodeRef.of(expression));
+        Object value = evaluateConstantExpression(expression, type, metadata, session, analysis.getParameters());
+
+        // Convert native representation (e.g., Slice, Block) to Java object (e.g., String, List)
+        BlockBuilder blockBuilder = type.createBlockBuilder(null, 1);
+        writeNativeValue(type, blockBuilder, value);
+        return type.getObjectValue(session.getSqlFunctionProperties(), blockBuilder, 0);
+    }
+
     private RelationPlan createRefreshMaterializedViewPlan(Analysis analysis, RefreshMaterializedView refreshMaterializedViewStatement)
     {
         Analysis.RefreshMaterializedViewAnalysis viewAnalysis = analysis.getRefreshMaterializedViewAnalysis().get();
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java
index f6bfcc5d75e58..21ea729936869 100644
---
a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java
@@ -82,6 +82,7 @@
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SAMPLE_PERCENTAGE_OUT_OF_RANGE;
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED;
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA;
+import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_ALREADY_EXISTS;
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_COLUMN_NOT_FOUND;
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE;
 import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_IMPLEMENTATION_ERROR;
@@ -2399,4 +2400,48 @@ public void testInvalidMerge()
         assertFails(NOT_SUPPORTED, "line 1:1: Merging into materialized views is not supported",
                 "MERGE INTO mv1 USING t1 ON mv1.a = t1.a WHEN MATCHED THEN UPDATE SET id = bar.id + 1");
     }
+
+    /**
+     * Analyzer coverage for CREATE VECTOR INDEX: statements that must analyze cleanly,
+     * followed by each rejected form and the semantic error code it must raise.
+     */
+    @Test
+    public void testCreateVectorIndex()
+    {
+        // Accepted forms, in order: bare, one property, two properties,
+        // UPDATING FOR alone, property plus UPDATING FOR, single indexed column.
+        String[] acceptedStatements = {
+                "CREATE VECTOR INDEX test_index ON t1(a, b)",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = 'val1')",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = 'val1', p2 = 'val2')",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) UPDATING FOR a > 10",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = 'val1') UPDATING FOR a BETWEEN 1 AND 100",
+                "CREATE VECTOR INDEX test_index ON t1(a)",
+        };
+        for (String statement : acceptedStatements) {
+            analyze(statement);
+        }
+
+        // Unknown source table is rejected
+        assertFails(MISSING_TABLE, ".*Source table '.*' does not exist",
+                "CREATE VECTOR INDEX test_index ON nonexistent_table(a, b)");
+
+        // Index name collides with an existing table (t1 is reused as the index name)
+        assertFails(TABLE_ALREADY_EXISTS, ".*already exists",
+                "CREATE VECTOR INDEX t1 ON t2(a, b)");
+
+        // Indexed column missing from the source table
+        assertFails(MISSING_COLUMN, ".*Column 'unknown' does not exist in source table '.*'",
+                "CREATE VECTOR INDEX test_index ON t1(a, unknown)");
+        assertFails(MISSING_COLUMN, ".*Column 'nonexistent' does not exist in source table '.*'",
+                "CREATE VECTOR INDEX test_index ON t1(nonexistent)");
+
+        // Same property given twice, unquoted and quoted
+        assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = 'v1', p2 = 'v2', p1 = 'v3')");
+        assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = 'v1', \"p1\" = 'v2')");
+
+        // Property value that is a bare, unresolvable name
+        assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) WITH (p1 = y)");
+
+        // UPDATING FOR expression naming a column that does not exist
+        assertFails(MISSING_ATTRIBUTE, ".*",
+                "CREATE VECTOR INDEX test_index ON t1(a, b) UPDATING FOR nonexistent_col > 10");
+    }
 }
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java
index 5540734c7ed47..8e7be0e171b5c 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java
@@ -502,6 +502,22 @@ default Optional finishCreateTable(ConnectorSession ses
         throw new PrestoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata beginCreateTable() is implemented without finishCreateTable()");
     }
 
+    /**
+     * Begin the atomic creation of a vector index with data.
+     * <p>The default implementation always throws {@link PrestoException} with {@code NOT_SUPPORTED}.
+     */
+    default ConnectorOutputTableHandle beginCreateVectorIndex(ConnectorSession session, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName)
+    {
+        // NOTE(review): Optional (here) and Collection (below) carry no type arguments in this copy
+        // of the patch, possibly stripped in transit; confirm against the original change.
+        throw new PrestoException(NOT_SUPPORTED, "This connector does not support creating vector indexes");
+    }
+
+    /**
+     * Finish a vector index creation with data after the data is written.
+     * <p>The default implementation always throws {@link PrestoException} with {@code NOT_SUPPORTED}.
+     */
+    default Optional finishCreateVectorIndex(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments, Collection computedStatistics)
+    {
+        throw new PrestoException(NOT_SUPPORTED, "This connector does not support creating vector indexes");
+    }
+
     /**
      * Start a SELECT/UPDATE/INSERT/DELETE query. This notification is triggered after the planning phase completes.
      */
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
index 63e0f082b788d..7e7c6c99bf4e7 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
@@ -450,6 +450,22 @@ public Optional finishCreateTable(ConnectorSession sess
         }
     }
 
+    // Forwards to the delegate inside a ThreadContextClassLoader scope built from classLoader
+    // (presumably so connector code runs with its own class loader; verify against ThreadContextClassLoader).
+    @Override
+    public ConnectorOutputTableHandle beginCreateVectorIndex(ConnectorSession session, ConnectorTableMetadata indexMetadata, Optional layout, SchemaTableName sourceTableName)
+    {
+        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+            return delegate.beginCreateVectorIndex(session, indexMetadata, layout, sourceTableName);
+        }
+    }
+
+    // Same forwarding pattern as beginCreateVectorIndex above.
+    @Override
+    public Optional finishCreateVectorIndex(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments, Collection computedStatistics)
+    {
+        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+
return delegate.finishCreateVectorIndex(session, tableHandle, fragments, computedStatistics);
+        }
+    }
+
     @Override
     public void beginQuery(ConnectorSession session)
     {
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java
index 3fcf43a488c8e..d11a28383bf51 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java
@@ -698,4 +698,67 @@ public String toString()
             return procedureName.toString();
         }
     }
+
+    /**
+     * Writer target describing a vector index to create: the connector, the metadata of the
+     * table to create, an optional layout, optional output columns, and the schema-qualified
+     * name of the source table. Every constructor argument is null-checked.
+     */
+    public static class CreateVectorIndexReference
+            extends WriterTarget
+    {
+        private final ConnectorId connectorId;
+        private final ConnectorTableMetadata tableMetadata;
+        // NOTE(review): "Optional layout" and "Optional> columns" appear without their type
+        // arguments in this copy of the patch ("Optional>" is not valid Java); the arguments
+        // seem to have been stripped in transit. Restore them from the original change.
+        private final Optional layout;
+        private final Optional> columns;
+        private final SchemaTableName sourceTableName;
+
+        public CreateVectorIndexReference(
+                ConnectorId connectorId,
+                ConnectorTableMetadata tableMetadata,
+                Optional layout,
+                Optional> columns,
+                SchemaTableName sourceTableName)
+        {
+            this.connectorId = requireNonNull(connectorId, "connectorId is null");
+            this.tableMetadata = requireNonNull(tableMetadata, "tableMetadata is null");
+            this.layout = requireNonNull(layout, "layout is null");
+            this.columns = requireNonNull(columns, "columns is null");
+            this.sourceTableName = requireNonNull(sourceTableName, "sourceTableName is null");
+        }
+
+        @Override
+        public ConnectorId getConnectorId()
+        {
+            return connectorId;
+        }
+
+        public ConnectorTableMetadata getTableMetadata()
+        {
+            return tableMetadata;
+        }
+
+        public Optional getLayout()
+        {
+            return layout;
+        }
+
+        // The target's schema/table name is taken from the table metadata, not from the source table
+        @Override
+        public SchemaTableName getSchemaTableName()
+        {
+            return tableMetadata.getTable();
+        }
+
+        // Renders as "<connectorId>.<target table name>"
+        @Override
+        public String toString()
+        {
+            return connectorId + "." + tableMetadata.getTable();
+        }
+
+        @Override
+        public Optional> getOutputColumns()
+        {
+            return columns;
+        }
+
+        public SchemaTableName getSourceTableName()
+        {
+            return sourceTableName;
+        }
+    }
 }