diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index 54fc2dd4cc99e..a2a6d91c6b914 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -168,7 +168,11 @@ statement SET updateAssignment (',' updateAssignment)* (WHERE where=booleanExpression)? #update | MERGE INTO qualifiedName (AS? identifier)? - USING relation ON expression mergeCase+ #mergeInto + USING relation ON expression mergeCase+ #mergeInto + | CREATE VECTOR INDEX identifier ON TABLE qualifiedName + '(' identifier (',' identifier)? ')' + (WITH properties)? + (UPDATING FOR booleanExpression)? #createVectorIndex ; query @@ -194,7 +198,7 @@ likeClause ; properties - : '(' property (',' property)* ')' + : '(' property (',' property)* ','? ')' ; property @@ -704,7 +708,7 @@ nonReserved | FETCH | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS | HOUR - | IF | IGNORE | INCLUDING | INPUT | INTERVAL | INVOKER | IO | ISOLATION + | IF | IGNORE | INCLUDING | INDEX | INPUT | INTERVAL | INVOKER | IO | ISOLATION | JSON | KEEP | KEY | LANGUAGE | LAST | LATERAL | LEVEL | LIMIT | LOGICAL @@ -716,8 +720,8 @@ nonReserved | SCHEMA | SCHEMAS | SECOND | SECURITY | SERIALIZABLE | SESSION | SET | SETS | SNAPSHOT | SNAPSHOTS | SQL | SHOW | SOME | START | STATS | SUBSTRING | SYSTEM | SYSTEM_TIME | SYSTEM_VERSION | TABLES | TABLESAMPLE | TAG | TEMPORARY | TEXT | TIME | TIMESTAMP | TO | TRANSACTION | TRUNCATE | TRY_CAST | TYPE - | UNBOUNDED | UNCOMMITTED | UNIQUE | UPDATE | USE | USER - | VALIDATE | VERBOSE | VERSION | VIEW + | UNBOUNDED | UNCOMMITTED | UNIQUE | UPDATE | UPDATING | USE | USER + | VALIDATE | VECTOR | VERBOSE | VERSION | VIEW | WORK | WRITE | YEAR | ZONE @@ -813,6 +817,7 @@ IF: 'IF'; IGNORE: 'IGNORE'; IN: 'IN'; INCLUDING: 'INCLUDING'; +INDEX: 'INDEX'; INNER: 'INNER'; INPUT: 'INPUT'; INSERT: 'INSERT'; @@ -942,10 +947,12 @@ UNIQUE: 'UNIQUE'; UNNEST: 'UNNEST'; UPDATE: 'UPDATE'; USE: 'USE'; +UPDATING: 'UPDATING'; USER: 'USER'; USING: 'USING'; VALIDATE: 'VALIDATE'; VALUES: 'VALUES'; +VECTOR: 'VECTOR'; VERBOSE: 'VERBOSE'; VERSION: 'VERSION'; VIEW: 'VIEW'; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index d7d7c3a1e0cef..97aa2dce6701c 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -35,6 +35,7 @@ import com.facebook.presto.sql.tree.CreateTable; import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateTag; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Deallocate; import com.facebook.presto.sql.tree.Delete; @@ -1285,6 +1286,36 @@ protected Void visitCreateTable(CreateTable node, Integer indent) return null; } + @Override + protected Void visitCreateVectorIndex(CreateVectorIndex node, Integer indent) + { + builder.append("CREATE VECTOR INDEX "); + builder.append(formatName(node.getIndexName())); + builder.append(" ON TABLE "); + builder.append(formatName(node.getTableName())); + builder.append(" ("); + builder.append(node.getColumns().stream() + .map(Formatter::formatName) + .collect(joining(", "))); + builder.append(")"); + + if (!node.getProperties().isEmpty()) { + builder.append("\nWITH ("); + builder.append(node.getProperties().stream() + .map(property -> formatName(property.getName()) + " = " + + formatExpression(property.getValue(), parameters)) + .collect(joining(", "))); + builder.append(")"); + } + + node.getUpdatingFor().ifPresent(updatingFor -> { + builder.append("\nUPDATING FOR "); + builder.append(formatExpression(updatingFor, parameters)); + }); + + return null; + } + private String formatPropertiesMultiLine(List properties) { if (properties.isEmpty()) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index b1ad2a9c290d2..d2e22ed48158f 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -48,6 +48,7 @@ import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateTag; import com.facebook.presto.sql.tree.CreateType; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Cube; import com.facebook.presto.sql.tree.CurrentTime; @@ -368,6 +369,37 @@ public Node visitCreateTable(SqlBaseParser.CreateTableContext context) comment); } + @Override + public Node visitCreateVectorIndex(SqlBaseParser.CreateVectorIndexContext context) + { + Identifier indexName = (Identifier) visit(context.identifier(0)); + QualifiedName tableName = getQualifiedName(context.qualifiedName()); + + // Columns start from identifier(1) onwards + List columns = context.identifier().stream() + .skip(1) // Skip index name + .map(id -> (Identifier) visit(id)) + .collect(toImmutableList()); + + Optional updatingFor = Optional.empty(); + if (context.UPDATING() != null) { + updatingFor = Optional.of((Expression) visit(context.booleanExpression())); + } + + List properties = ImmutableList.of(); + if (context.properties() != null) { + properties = visit(context.properties().property(), Property.class); + } + + return new CreateVectorIndex( + getLocation(context), + indexName, + tableName, + columns, + updatingFor, + properties); + } + @Override public Node visitCreateType(SqlBaseParser.CreateTypeContext context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index 40ae3d7398bc1..e5d4123ba95b6 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -582,6 +582,11 @@ protected R visitCreateTable(CreateTable node, C context) return visitStatement(node, context); } + protected R visitCreateVectorIndex(CreateVectorIndex node, C context) + { + return visitStatement(node, context); + } + protected R visitCreateType(CreateType node, C context) { return visitStatement(node, context); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateVectorIndex.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateVectorIndex.java new file mode 100644 index 0000000000000..76c6a230051c2 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/CreateVectorIndex.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class CreateVectorIndex + extends Statement +{ + private final Identifier indexName; + private final QualifiedName tableName; + private final List columns; + private final Optional updatingFor; + private final List properties; + + public CreateVectorIndex( + Identifier indexName, + QualifiedName tableName, + List columns, + Optional updatingFor, + List properties) + { + this(Optional.empty(), indexName, tableName, columns, updatingFor, properties); + } + + public CreateVectorIndex( + NodeLocation location, + Identifier indexName, + QualifiedName tableName, + List columns, + Optional updatingFor, + List properties) + { + this(Optional.of(location), indexName, tableName, columns, updatingFor, properties); + } + + private CreateVectorIndex( + Optional location, + Identifier indexName, + QualifiedName tableName, + List columns, + Optional updatingFor, + List properties) + { + super(location); + this.indexName = requireNonNull(indexName, "indexName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + this.updatingFor = requireNonNull(updatingFor, "updatingFor is null"); + this.properties = ImmutableList.copyOf(requireNonNull(properties, "properties is null")); + } + + public Identifier getIndexName() + { + return indexName; + } + + public QualifiedName getTableName() + { + return tableName; + } + + public List getColumns() + { + return columns; + } + + public Optional getUpdatingFor() + { + return updatingFor; + } + + public List getProperties() + { + return properties; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitCreateVectorIndex(this, context); + } + + @Override + public List getChildren() + { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(indexName); + children.addAll(columns); + updatingFor.ifPresent(children::add); + children.addAll(properties); + return children.build(); + } + + @Override + public int hashCode() + { + return Objects.hash(indexName, tableName, columns, updatingFor, properties); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + CreateVectorIndex o = (CreateVectorIndex) obj; + return Objects.equals(indexName, o.indexName) && + Objects.equals(tableName, o.tableName) && + Objects.equals(columns, o.columns) && + Objects.equals(updatingFor, o.updatingFor) && + Objects.equals(properties, o.properties); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("indexName", indexName) + .add("tableName", tableName) + .add("columns", columns) + .add("updatingFor", updatingFor) + .add("properties", properties) + .toString(); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java index 8ad4a3b381df6..8a386cc85715b 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java @@ -590,6 +590,20 @@ protected R visitCreateTable(CreateTable node, C context) return null; } + @Override + protected R visitCreateVectorIndex(CreateVectorIndex node, C context) + { + process(node.getIndexName(), context); + for (Identifier column : node.getColumns()) { + process(column, context); + } + node.getUpdatingFor().ifPresent(updatingFor -> process(updatingFor, context)); + for (Property property : node.getProperties()) { + process(property, context); + } + return null; + } + @Override protected R visitStartTransaction(StartTransaction node, C context) { diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index 01a60fce95955..3d6cbbbfad9bb 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -44,6 +44,7 @@ import com.facebook.presto.sql.tree.CreateTable; import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateTag; +import com.facebook.presto.sql.tree.CreateVectorIndex; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Cube; import com.facebook.presto.sql.tree.CurrentTime; @@ -1379,6 +1380,170 @@ public void testCreateTableWithNotNull() Optional.empty())); } + @Test + public void testCreateVectorIndex() + { + // Basic CREATE VECTOR INDEX + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(a, b)", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("a"), identifier("b")), + Optional.empty(), + ImmutableList.of())); + + // Single column + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(a)", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("a")), + Optional.empty(), + ImmutableList.of())); + + // With qualified table name + assertStatement("CREATE VECTOR INDEX idx ON TABLE catalog.schema.t(a, b)", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("catalog", "schema", "t"), + ImmutableList.of(identifier("a"), identifier("b")), + Optional.empty(), + ImmutableList.of())); + + // With qualified table name single column + assertStatement("CREATE VECTOR INDEX idx ON TABLE catalog.schema.t(a)", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("catalog", "schema", "t"), + ImmutableList.of(identifier("a")), + Optional.empty(), + ImmutableList.of())); + + // With single property + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(c) WITH (index_type = 'ivf')", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("c")), + Optional.empty(), + ImmutableList.of( + new Property(identifier("index_type"), new StringLiteral("ivf"))))); + + // With multiple properties + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(c) WITH (index_type = 'ivf', metric = 'cosine')", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("c")), + Optional.empty(), + ImmutableList.of( + new Property(identifier("index_type"), new StringLiteral("ivf")), + new Property(identifier("metric"), new StringLiteral("cosine"))))); + + // With trailing comma in properties (grammar allows it) + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(c) WITH (index_type = 'ivf',)", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("c")), + Optional.empty(), + ImmutableList.of( + new Property(identifier("index_type"), new StringLiteral("ivf"))))); + + // With UPDATING FOR equality + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(c) UPDATING FOR ds = '2024-01-01'", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("c")), + Optional.of(new ComparisonExpression( + EQUAL, + new Identifier("ds"), + new StringLiteral("2024-01-01"))), + ImmutableList.of())); + + // With UPDATING FOR BETWEEN expression + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(a, b) UPDATING FOR ds BETWEEN '2024-01-01' AND '2024-01-31'", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("a"), identifier("b")), + Optional.of(new BetweenPredicate( + new Identifier("ds"), + new StringLiteral("2024-01-01"), + new StringLiteral("2024-01-31"))), + ImmutableList.of())); + + // With properties and UPDATING FOR + assertStatement("CREATE VECTOR INDEX idx ON TABLE t(c) WITH (index_type = 'ivf') UPDATING FOR ds = '2024-01-01'", + new CreateVectorIndex( + identifier("idx"), + QualifiedName.of("t"), + ImmutableList.of(identifier("c")), + Optional.of(new ComparisonExpression( + EQUAL, + new Identifier("ds"), + new StringLiteral("2024-01-01"))), + ImmutableList.of( + new Property(identifier("index_type"), new StringLiteral("ivf"))))); + + // Full example with all clauses + assertStatement("CREATE VECTOR INDEX my_index ON TABLE catalog.schema.t(id, embedding) WITH (index_type = 'ivf_rabitq4', distance_metric = 'cosine') UPDATING FOR ds BETWEEN '2024-01-01' AND '2024-01-31'", + new CreateVectorIndex( + identifier("my_index"), + QualifiedName.of("catalog", "schema", "t"), + ImmutableList.of(identifier("id"), identifier("embedding")), + Optional.of(new BetweenPredicate( + new Identifier("ds"), + new StringLiteral("2024-01-01"), + new StringLiteral("2024-01-31"))), + ImmutableList.of( + new Property(identifier("index_type"), new StringLiteral("ivf_rabitq4")), + new Property(identifier("distance_metric"), new StringLiteral("cosine"))))); + + // Negative tests + + // Missing index name + assertInvalidStatement("CREATE VECTOR INDEX ON TABLE t(a)", "mismatched input 'ON'.*"); + + // Missing column list + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t", "mismatched input ''.*"); + + // Empty column list + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t()", "mismatched input '\\)'.*"); + + // Missing table name + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE (a)", "mismatched input '\\('.*"); + + // Missing ON keyword + assertInvalidStatement("CREATE VECTOR INDEX idx TABLE t(a)", "mismatched input 'TABLE'.*"); + + // Missing VECTOR keyword + assertInvalidStatement("CREATE INDEX idx ON TABLE t(a)", "mismatched input 'INDEX'.*"); + + // Missing INDEX keyword + assertInvalidStatement("CREATE VECTOR idx ON TABLE t(a)", "mismatched input 'idx'.*"); + + // Missing TABLE keyword + assertInvalidStatement("CREATE VECTOR INDEX idx ON t(a)", "mismatched input 't'.*"); + + // WITH without parentheses + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t(a) WITH index_type = 'ivf'", + "mismatched input 'index_type'.*"); + + // Invalid WHERE clause instead of UPDATING FOR + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t(a) WHERE a > 1", + "mismatched input 'WHERE'.*"); + + // UPDATING without FOR + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t(a) UPDATING ds = '2024-01-01'", + "mismatched input 'ds'.*"); + + // UPDATING FOR without expression + assertInvalidStatement("CREATE VECTOR INDEX idx ON TABLE t(a) UPDATING FOR", + "mismatched input ''.*"); + } + @Test public void testCreateTableAsSelect() { diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java index 3dea840f8a7fc..57b11d878c81b 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParserErrorHandling.java @@ -81,9 +81,9 @@ public Object[][] getStatements() {"select foo(DISTINCT ,1)", "line 1:21: mismatched input ','. Expecting: "}, {"CREATE TABLE foo () AS (VALUES 1)", - "line 1:19: mismatched input ')'. Expecting: 'FUNCTION', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'TEMPORARY', 'TYPE', 'VIEW'"}, + "line 1:19: mismatched input ')'. Expecting: 'FUNCTION', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'TEMPORARY', 'TYPE', 'VECTOR', 'VIEW'"}, {"CREATE TABLE foo (*) AS (VALUES 1)", - "line 1:19: mismatched input '*'. Expecting: 'FUNCTION', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'TEMPORARY', 'TYPE', 'VIEW'"}, + "line 1:19: mismatched input '*'. Expecting: 'FUNCTION', 'MATERIALIZED', 'OR', 'ROLE', 'SCHEMA', 'TABLE', 'TEMPORARY', 'TYPE', 'VECTOR', 'VIEW'"}, {"SELECT grouping(a+2) FROM (VALUES (1)) AS t (a) GROUP BY a+2", "line 1:18: mismatched input '+'. Expecting: ')', ','"}, {"SELECT x() over (ROWS select) FROM t",