diff --git a/api/src/main/java/org/apache/gravitino/rel/TableChange.java b/api/src/main/java/org/apache/gravitino/rel/TableChange.java index cb6c741dea5..c0b04f58857 100644 --- a/api/src/main/java/org/apache/gravitino/rel/TableChange.java +++ b/api/src/main/java/org/apache/gravitino/rel/TableChange.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import java.util.Arrays; +import java.util.Map; import java.util.Objects; import java.util.Optional; import org.apache.gravitino.annotation.Evolving; @@ -453,6 +454,20 @@ static TableChange addIndex(IndexType type, String name, String[][] fieldNames) return new AddIndex(type, name, fieldNames); } + /** + * Create a TableChange for adding an index. + * + * @param type The type of the index. + * @param name The name of the index. + * @param fieldNames The field names of the index. + * @param properties The properties of the index. + * @return A TableChange for the add index. + */ + static TableChange addIndex( + IndexType type, String name, String[][] fieldNames, Map properties) { + return new AddIndex(type, name, fieldNames, properties); + } + /** * Create a TableChange for deleting an index. * @@ -747,15 +762,29 @@ final class AddIndex implements TableChange { private final String[][] fieldNames; + private final Map properties; + /** * @param type The type of the index. * @param name The name of the index. * @param fieldNames The field names of the index. */ public AddIndex(IndexType type, String name, String[][] fieldNames) { + this(type, name, fieldNames, Map.of()); + } + + /** + * @param type The type of the index. + * @param name The name of the index. + * @param fieldNames The field names of the index. + * @param properties The properties of the index. + */ + public AddIndex( + IndexType type, String name, String[][] fieldNames, Map properties) { this.type = type; this.name = name; this.fieldNames = fieldNames; + this.properties = properties; } /** @@ -779,6 +808,13 @@ public String[][] getFieldNames() { return fieldNames; } + /** + * @return The properties of the index. + */ + public Map getProperties() { + return properties; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -786,13 +822,15 @@ public boolean equals(Object o) { AddIndex addIndex = (AddIndex) o; return type == addIndex.type && Objects.equals(name, addIndex.name) - && Arrays.deepEquals(fieldNames, addIndex.fieldNames); + && Arrays.deepEquals(fieldNames, addIndex.fieldNames) + && Objects.equals(properties, addIndex.properties); } @Override public int hashCode() { int result = Objects.hash(type, name); result = 31 * result + Arrays.hashCode(fieldNames); + result = 31 * result + Objects.hashCode(properties); return result; } } diff --git a/api/src/main/java/org/apache/gravitino/rel/indexes/Index.java b/api/src/main/java/org/apache/gravitino/rel/indexes/Index.java index fd3beb0e67f..b8decc5b149 100644 --- a/api/src/main/java/org/apache/gravitino/rel/indexes/Index.java +++ b/api/src/main/java/org/apache/gravitino/rel/indexes/Index.java @@ -19,6 +19,7 @@ package org.apache.gravitino.rel.indexes; +import java.util.Map; import org.apache.gravitino.annotation.Evolving; /** @@ -44,6 +45,13 @@ public interface Index { */ String[][] fieldNames(); + /** + * @return The properties of the index. + */ + default Map properties() { + return Map.of(); + } + /** * The enum IndexType defines the type of the index. Currently, PRIMARY_KEY and UNIQUE_KEY are * supported. @@ -102,15 +110,44 @@ enum IndexType { * Currently, this type is only applicable to Lance. */ VECTOR, - /** IVF_FLAT (Inverted File with Flat quantization) is an indexing method used for efficient */ + + /** + * IVF_FLAT (Inverted File with Flat Quantization) is an indexing method used for efficient + * approximate nearest neighbor search in high-dimensional vector spaces. It stores the original + * vectors without any quantization. + */ IVF_FLAT, - /** IVF_SQ (Inverted File with Scalar Quantization) is an indexing method used for efficient */ + /** + * IVF_SQ (Inverted File with Scalar Quantization) is an indexing method used for efficient + * approximate nearest neighbor search in high-dimensional vector spaces. It applies scalar + * quantization to reduce the storage size of the vectors. + */ IVF_SQ, - /** IVF_PQ (Inverted File with Product Quantization) is an indexing method used for efficient */ + /** + * IVF_PQ (Inverted File with Product Quantization) is an indexing method used for efficient + * approximate nearest neighbor search in high-dimensional vector spaces. It applies product + * quantization to compress the vectors into smaller codes for faster search and reduced + * storage. + */ IVF_PQ, - /** IVF_HNSW_FLAT */ + + /** + * IVF_HNSW_SQ is an indexing method that combines Inverted File (IVF) with Hierarchical + * Navigable Small World (HNSW) graphs and Scalar Quantization (SQ) for efficient approximate + * nearest neighbor search in high-dimensional vector spaces. + */ IVF_HNSW_SQ, - /** IVF_HNSW_PQ */ - IVF_HNSW_PQ; + /** + * IVF_HNSW_PQ is an indexing method that combines Inverted File (IVF) with Hierarchical + * Navigable Small World (HNSW) graphs and Product Quantization (PQ) for efficient approximate + * nearest neighbor search in high-dimensional vector spaces. + */ + IVF_HNSW_PQ, + /** + * FTS index is a data structure used for efficient storage and retrieval of strings, enabling + * fast prefix-based searches and pattern matching. Currently, this type is only applicable to + * Lance. + */ + FTS; } } diff --git a/api/src/main/java/org/apache/gravitino/rel/indexes/Indexes.java b/api/src/main/java/org/apache/gravitino/rel/indexes/Indexes.java index c6303e7938b..ca7bc6327aa 100644 --- a/api/src/main/java/org/apache/gravitino/rel/indexes/Indexes.java +++ b/api/src/main/java/org/apache/gravitino/rel/indexes/Indexes.java @@ -20,6 +20,7 @@ import com.google.common.base.Objects; import java.util.Arrays; +import java.util.Map; /** Helper methods to create index to pass into Apache Gravitino. */ public class Indexes { @@ -76,6 +77,28 @@ public static Index of(Index.IndexType indexType, String name, String[][] fieldN .build(); } + /** + * Create an index. + * + * @param indexType The type of the index + * @param name The name of the index + * @param fieldNames The field names under the table contained in the index. + * @param properties The properties of the index. + * @return The index to be created. + */ + public static Index of( + Index.IndexType indexType, + String name, + String[][] fieldNames, + Map properties) { + return IndexImpl.builder() + .withIndexType(indexType) + .withName(name) + .withFieldNames(fieldNames) + .withProperties(properties) + .build(); + } + /** The user side implementation of the index. */ public static final class IndexImpl implements Index { private final IndexType indexType; @@ -84,6 +107,7 @@ public static final class IndexImpl implements Index { private final String[][] fieldNames; + private final Map properties; /** * The constructor of the index. * @@ -91,10 +115,12 @@ public static final class IndexImpl implements Index { * @param name The name of the index * @param fieldNames The field names under the table contained in the index. */ - private IndexImpl(IndexType indexType, String name, String[][] fieldNames) { + private IndexImpl( + IndexType indexType, String name, String[][] fieldNames, Map properties) { this.indexType = indexType; this.name = name; this.fieldNames = fieldNames; + this.properties = properties; } /** @@ -121,6 +147,11 @@ public String[][] fieldNames() { return fieldNames; } + @Override + public Map properties() { + return properties; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -132,12 +163,15 @@ public boolean equals(Object o) { IndexImpl index = (IndexImpl) o; return indexType == index.indexType && Objects.equal(name, index.name) - && Arrays.deepEquals(fieldNames, index.fieldNames); + && Arrays.deepEquals(fieldNames, index.fieldNames) + && Objects.equal(properties, index.properties); } @Override public int hashCode() { - return Objects.hashCode(indexType, name, Arrays.hashCode(fieldNames)); + int result = Objects.hashCode(indexType, name, properties); + result = 31 * result + Arrays.deepHashCode(fieldNames); + return result; } /** @@ -159,6 +193,9 @@ public static class Builder { /** The field names of the index. */ protected String[][] fieldNames; + /** The properties of the index. */ + protected Map properties; + /** * Set the type of the index. * @@ -192,13 +229,24 @@ public Indexes.IndexImpl.Builder withFieldNames(String[][] fieldNames) { return this; } + /** + * Set the properties of the index. + * + * @param properties The properties of the index + * @return The builder for creating a new instance of IndexImpl. + */ + public Indexes.IndexImpl.Builder withProperties(Map properties) { + this.properties = properties; + return this; + } + /** * Build a new instance of IndexImpl. * * @return The new instance. */ public Index build() { - return new IndexImpl(indexType, name, fieldNames); + return new IndexImpl(indexType, name, fieldNames, properties); } } } diff --git a/build.gradle.kts b/build.gradle.kts index be919c3c0d3..c01e709048f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -858,17 +858,6 @@ tasks { destinationDirectory.set(projectDir.dir("distribution")) } - val assembleLanceRESTServer by registering(Tar::class) { - dependsOn("compileLanceRESTServer") - group = "gravitino distribution" - finalizedBy("checksumLanceRESTServerDistribution") - into("${rootProject.name}-lance-rest-server-$version-bin") - from(compileLanceRESTServer.map { it.outputs.files.single() }) - compression = Compression.GZIP - archiveFileName.set("${rootProject.name}-lance-rest-server-$version-bin.tar.gz") - destinationDirectory.set(projectDir.dir("distribution")) - } - val assembleIcebergRESTServer by registering(Tar::class) { dependsOn("compileIcebergRESTServer") group = "gravitino distribution" @@ -880,6 +869,17 @@ tasks { destinationDirectory.set(projectDir.dir("distribution")) } + val assembleLanceRESTServer by registering(Tar::class) { + dependsOn("compileLanceRESTServer") + group = "gravitino distribution" + finalizedBy("checksumLanceRESTServerDistribution") + into("${rootProject.name}-lance-rest-server-$version-bin") + from(compileLanceRESTServer.map { it.outputs.files.single() }) + compression = Compression.GZIP + archiveFileName.set("${rootProject.name}-lance-rest-server-$version-bin.tar.gz") + destinationDirectory.set(projectDir.dir("distribution")) + } + register("checksumIcebergRESTServerDistribution") { group = "gravitino distribution" dependsOn(assembleIcebergRESTServer) @@ -914,7 +914,12 @@ tasks { register("checksumDistribution") { group = "gravitino distribution" - dependsOn(assembleDistribution, "checksumTrinoConnector", "checksumIcebergRESTServerDistribution", "checksumLanceRESTServerDistribution") + dependsOn( + assembleDistribution, + "checksumTrinoConnector", + "checksumIcebergRESTServerDistribution", + "checksumLanceRESTServerDistribution" + ) val archiveFile = assembleDistribution.flatMap { it.archiveFile } val checksumFile = archiveFile.map { archive -> archive.asFile.let { it.resolveSibling("${it.name}.sha256") } diff --git a/catalogs/catalog-lakehouse-generic/build.gradle.kts b/catalogs/catalog-lakehouse-generic/build.gradle.kts index 484514bede5..2dee42cf2d7 100644 --- a/catalogs/catalog-lakehouse-generic/build.gradle.kts +++ b/catalogs/catalog-lakehouse-generic/build.gradle.kts @@ -43,7 +43,11 @@ dependencies { implementation(libs.commons.io) implementation(libs.commons.lang3) implementation(libs.guava) - implementation(libs.lance) + implementation(libs.hadoop3.client.api) + implementation(libs.lance) { + exclude(group = "com.google.guava", module = "guava") // provided by gravitino + exclude(group = "org.apache.commons", module = "commons-lang3") // provided by gravitino + } annotationProcessor(libs.lombok) diff --git a/catalogs/catalog-lakehouse-generic/src/main/java/org/apache/gravitino/catalog/lakehouse/lance/LanceTableOperations.java b/catalogs/catalog-lakehouse-generic/src/main/java/org/apache/gravitino/catalog/lakehouse/lance/LanceTableOperations.java index 2eed5288be1..8d610c87b24 100644 --- a/catalogs/catalog-lakehouse-generic/src/main/java/org/apache/gravitino/catalog/lakehouse/lance/LanceTableOperations.java +++ b/catalogs/catalog-lakehouse-generic/src/main/java/org/apache/gravitino/catalog/lakehouse/lance/LanceTableOperations.java @@ -18,13 +18,24 @@ */ package org.apache.gravitino.catalog.lakehouse.lance; +import static org.apache.gravitino.lance.common.utils.LanceConstants.LANCE_INDEX_CONFIG_KEY; + import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import com.lancedb.lance.Dataset; import com.lancedb.lance.WriteParams; import com.lancedb.lance.index.DistanceType; import com.lancedb.lance.index.IndexParams; import com.lancedb.lance.index.IndexType; +import com.lancedb.lance.index.scalar.ScalarIndexParams; +import com.lancedb.lance.index.vector.HnswBuildParams; +import com.lancedb.lance.index.vector.IvfBuildParams; +import com.lancedb.lance.index.vector.PQBuildParams; +import com.lancedb.lance.index.vector.SQBuildParams; import com.lancedb.lance.index.vector.VectorIndexParams; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest.MetricTypeEnum; +import com.lancedb.lance.namespace.util.JsonUtil; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -159,12 +170,20 @@ public Table alterTable(NameIdentifier ident, TableChange... changes) .withIndexType(addIndexChange.getType()) .withName(addIndexChange.getName()) .withFieldNames(addIndexChange.getFieldNames()) + .withProperties(addIndexChange.getProperties()) .build(); }) .collect(Collectors.toList()); Table loadedTable = super.loadTable(ident); - addLanceIndex(loadedTable, addedIndexes); + + String location = loadedTable.properties().get(Table.PROPERTY_LOCATION); + List addedIndex = addLanceIndex(location, addedIndexes); + + // Since Lance supports adding indexes without an index name, and it will generate index name + // automatically, we need to modify the TableChange to include the index name after adding the + // index to Lance dataset. + changes = modifyAddIndex(changes, addedIndex); // After adding the index to the Lance dataset, we need to update the table metadata in // Gravitino. If there's any failure during this process, the code will throw an exception // and the update won't be applied in Gravitino. @@ -240,43 +259,135 @@ private org.apache.arrow.vector.types.pojo.Schema convertColumnsToArrowSchema(Co return new org.apache.arrow.vector.types.pojo.Schema(fields); } - private void addLanceIndex(Table table, List addedIndexes) { - String location = table.properties().get(Table.PROPERTY_LOCATION); - try (Dataset dataset = Dataset.open(location, new RootAllocator())) { - // For Lance, we only support adding indexes, so in fact, we can't handle drop index here. - for (Index index : addedIndexes) { - IndexType indexType = IndexType.valueOf(index.type().name()); - IndexParams indexParams = getIndexParamsByIndexType(indexType); + private TableChange[] modifyAddIndex(TableChange[] tableChanges, List addIndex) { + int indexCount = 0; + for (int i = 0; i < tableChanges.length; i++) { + TableChange change = tableChanges[i]; + if (change instanceof TableChange.AddIndex) { + Index index = addIndex.get(indexCount++); + tableChanges[i] = + new TableChange.AddIndex( + index.type(), index.name(), index.fieldNames(), index.properties()); + } + } + + return tableChanges; + } + private List addLanceIndex(String location, List addedIndexes) { + List newIndexes = Lists.newArrayList(); + try (RootAllocator rootAllocator = new RootAllocator(); + Dataset dataset = Dataset.open(location, rootAllocator)) { + for (Index index : addedIndexes) { + IndexType indexType = getIndexType(index); + IndexParams indexParams = generateIndexParams(index); dataset.createIndex( Arrays.stream(index.fieldNames()) - .map(field -> String.join(".", field)) + .map(fieldPath -> String.join(".", fieldPath)) .collect(Collectors.toList()), indexType, - Optional.of(index.name()), + Optional.ofNullable(index.name()), indexParams, - true); + false); + + // Currently lance only supports single-field indexes, so we can use the first field name. + // Another point is that we need to ensure the index name is not null in Gravitino, so we + // generate a name if it's null as Lance will generate a name automatically. + String lanceIndexName = + index.name() == null ? index.fieldNames()[0][0] + "_idx" : index.name(); + newIndexes.add( + Indexes.of(index.type(), lanceIndexName, index.fieldNames(), index.properties())); } - } catch (Exception e) { - throw new RuntimeException( - "Failed to add indexes to Lance dataset at location " + location, e); + + return newIndexes; } } - private IndexParams getIndexParamsByIndexType(IndexType indexType) { + private IndexType getIndexType(Index index) { + IndexType indexType = IndexType.valueOf(index.type().name()); + return switch (indexType) { + // API only supports these index types for now, but there are more index types in Lance. + case SCALAR, BTREE, INVERTED, BITMAP -> indexType; + // According to real test, we need to map IVF_SQ/IVF_PQ/IVF_HNSW_SQ to VECTOR type in Lance, + // or it will throw exception. For more, please refer to + // https://github.com/lancedb/lance/issues/5182#issuecomment-3524372490 + case IVF_FLAT, IVF_PQ, IVF_HNSW_SQ -> IndexType.VECTOR; + default -> throw new IllegalArgumentException("Unsupported index type: " + indexType); + }; + } + + private IndexParams generateIndexParams(Index index) { + IndexType indexType = IndexType.valueOf(index.type().name()); + + String configJson = index.properties().get(LANCE_INDEX_CONFIG_KEY); + Preconditions.checkArgument( + StringUtils.isNotBlank(configJson), + "Lance index config must be provided in index properties with key %s", + LANCE_INDEX_CONFIG_KEY); + CreateTableIndexRequest request; + try { + request = JsonUtil.mapper().readValue(configJson, CreateTableIndexRequest.class); + } catch (Exception e) { + throw new IllegalArgumentException("Lance index config is invalid", e); + } + + IndexParams.Builder builder = IndexParams.builder(); switch (indexType) { - case SCALAR: - return IndexParams.builder().build(); - case VECTOR: - // TODO make these parameters configurable - int numberOfDimensions = 3; // this value should be determined dynamically based on the data - // Add properties to Index to set this value. - return IndexParams.builder() - .setVectorIndexParams( - VectorIndexParams.ivfPq(2, 8, numberOfDimensions, DistanceType.L2, 2)) - .build(); - default: - throw new IllegalArgumentException("Unsupported index type: " + indexType); + case SCALAR, BTREE, INVERTED, BITMAP -> builder.setScalarIndexParams( + ScalarIndexParams.create(indexType.name())); + + case IVF_FLAT -> builder.setVectorIndexParams( + new VectorIndexParams.Builder(new IvfBuildParams.Builder().build()) + .setDistanceType(toLanceDistanceType(request.getMetricType())) + .build()); + case IVF_PQ -> builder.setVectorIndexParams( + new VectorIndexParams.Builder(new IvfBuildParams.Builder().build()) + .setDistanceType(toLanceDistanceType(request.getMetricType())) + .setPqParams( + new PQBuildParams.Builder() + .setNumSubVectors(1) // others use default value. + .build()) + .build()); + + case IVF_SQ -> builder.setVectorIndexParams( + new VectorIndexParams.Builder(new IvfBuildParams.Builder().build()) + .setDistanceType(toLanceDistanceType(request.getMetricType())) + .setSqParams(new SQBuildParams.Builder().build()) + .build()); + + case IVF_HNSW_SQ -> builder.setVectorIndexParams( + new VectorIndexParams.Builder(new IvfBuildParams.Builder().build()) + .setDistanceType(toLanceDistanceType(request.getMetricType())) + .setHnswParams(new HnswBuildParams.Builder().build()) + .build()); + + case IVF_HNSW_PQ -> builder.setVectorIndexParams( + new VectorIndexParams.Builder(new IvfBuildParams.Builder().build()) + .setDistanceType(toLanceDistanceType(request.getMetricType())) + .setHnswParams(new HnswBuildParams.Builder().build()) + .setPqParams( + new PQBuildParams.Builder() + .setNumSubVectors(1) // others use default value. + .build()) + .build()); + default -> throw new IllegalArgumentException("Unsupported index type: " + indexType); } + + return builder.build(); + } + + private DistanceType toLanceDistanceType(MetricTypeEnum metricTypeEnum) { + if (metricTypeEnum == null) { + // Default to L2 + return DistanceType.L2; + } + String metricName = metricTypeEnum.name(); + for (DistanceType distanceType : DistanceType.values()) { + if (distanceType.name().equalsIgnoreCase(metricName)) { + return distanceType; + } + } + + throw new IllegalArgumentException("Unsupported metric type: " + metricTypeEnum); } } diff --git a/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java b/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java index 8521501c015..831de85a2c5 100644 --- a/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java +++ b/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java @@ -226,7 +226,8 @@ static TableUpdateRequest toTableUpdateRequest(TableChange change) { return new TableUpdateRequest.AddTableIndexRequest( ((TableChange.AddIndex) change).getType(), ((TableChange.AddIndex) change).getName(), - ((TableChange.AddIndex) change).getFieldNames()); + ((TableChange.AddIndex) change).getFieldNames(), + ((TableChange.AddIndex) change).getProperties()); } else if (change instanceof TableChange.DeleteIndex) { return new TableUpdateRequest.DeleteTableIndexRequest( ((TableChange.DeleteIndex) change).getName(), diff --git a/common/src/main/java/org/apache/gravitino/dto/rel/indexes/IndexDTO.java b/common/src/main/java/org/apache/gravitino/dto/rel/indexes/IndexDTO.java index f165e59f925..28122db1ef0 100644 --- a/common/src/main/java/org/apache/gravitino/dto/rel/indexes/IndexDTO.java +++ b/common/src/main/java/org/apache/gravitino/dto/rel/indexes/IndexDTO.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.base.Preconditions; import java.util.Arrays; +import java.util.Map; import java.util.Objects; import org.apache.gravitino.json.JsonUtils.IndexDeserializer; import org.apache.gravitino.json.JsonUtils.IndexSerializer; @@ -38,6 +39,7 @@ public class IndexDTO implements Index { private IndexType indexType; private String name; private String[][] fieldNames; + private Map properties; /** Default constructor for Jackson deserialization. */ public IndexDTO() {} @@ -48,11 +50,14 @@ public IndexDTO() {} * @param indexType The type of the index. * @param name The name of the index. * @param fieldNames The names of the fields. + * @param properties The properties of the index. */ - public IndexDTO(IndexType indexType, String name, String[][] fieldNames) { + public IndexDTO( + IndexType indexType, String name, String[][] fieldNames, Map properties) { this.indexType = indexType; this.name = name; this.fieldNames = fieldNames; + this.properties = properties; } /** @@ -79,6 +84,11 @@ public String[][] fieldNames() { return fieldNames; } + @Override + public Map properties() { + return properties; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -87,7 +97,8 @@ public boolean equals(Object o) { return indexType == indexDTO.indexType && Objects.equals(name, indexDTO.name) - && compareStringArrays(fieldNames, indexDTO.fieldNames); + && compareStringArrays(fieldNames, indexDTO.fieldNames) + && Objects.equals(properties, indexDTO.properties); } private static boolean compareStringArrays(String[][] array1, String[][] array2) { @@ -108,6 +119,7 @@ public int hashCode() { for (String[] fieldName : fieldNames) { result = 31 * result + Arrays.hashCode(fieldName); } + result = 31 * result + Objects.hashCode(properties); return result; } @@ -131,6 +143,9 @@ public static class Builder { /** The names of the fields. */ protected String[][] fieldNames; + /** The properties of the index. */ + protected Map properties; + /** Default constructor. */ public Builder() {} @@ -167,6 +182,17 @@ public S withFieldNames(String[][] fieldNames) { return (S) this; } + /** + * Sets the properties of the index. + * + * @param properties The properties of the index. + * @return The builder. + */ + public S withProperties(Map properties) { + this.properties = properties; + return (S) this; + } + /** * Builds a new instance of IndexDTO. * @@ -177,7 +203,7 @@ public IndexDTO build() { Preconditions.checkArgument( fieldNames != null && fieldNames.length > 0, "The index must be set with corresponding column names"); - return new IndexDTO(indexType, name, fieldNames); + return new IndexDTO(indexType, name, fieldNames, properties); } } } diff --git a/common/src/main/java/org/apache/gravitino/dto/requests/TableUpdateRequest.java b/common/src/main/java/org/apache/gravitino/dto/requests/TableUpdateRequest.java index e31e14033ec..ce082f9146d 100644 --- a/common/src/main/java/org/apache/gravitino/dto/requests/TableUpdateRequest.java +++ b/common/src/main/java/org/apache/gravitino/dto/requests/TableUpdateRequest.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.base.Preconditions; import java.util.Arrays; +import java.util.Map; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -824,6 +825,19 @@ public AddTableIndexRequest(Index.IndexType type, String name, String[][] fieldN this.index = Indexes.of(type, name, fieldNames); } + /** + * The constructor of the add table index request. + * + * @param type Index type of the index to be added + * @param name Name of the index to be added + * @param fieldNames Field names of the index to be added + * @param properties Properties of the index to be added + */ + public AddTableIndexRequest( + Index.IndexType type, String name, String[][] fieldNames, Map properties) { + this.index = Indexes.of(type, name, fieldNames, properties); + } + /** * Validates the request. * @@ -843,7 +857,8 @@ public void validate() throws IllegalArgumentException { */ @Override public TableChange tableChange() { - return TableChange.addIndex(index.type(), index.name(), index.fieldNames()); + return TableChange.addIndex( + index.type(), index.name(), index.fieldNames(), index.properties()); } } diff --git a/common/src/main/java/org/apache/gravitino/json/JsonUtils.java b/common/src/main/java/org/apache/gravitino/json/JsonUtils.java index b35243b17bc..3a2b1531538 100644 --- a/common/src/main/java/org/apache/gravitino/json/JsonUtils.java +++ b/common/src/main/java/org/apache/gravitino/json/JsonUtils.java @@ -121,6 +121,7 @@ public class JsonUtils { private static final String INDEX_TYPE = "indexType"; private static final String INDEX_NAME = "name"; private static final String INDEX_FIELD_NAMES = "fieldNames"; + private static final String INDEX_PROPERTIES = "properties"; private static final String NUMBER = "number"; private static final String TYPE = "type"; private static final String STRUCT = "struct"; @@ -1473,6 +1474,12 @@ public void serialize(Index value, JsonGenerator gen, SerializerProvider seriali } gen.writeFieldName(INDEX_FIELD_NAMES); gen.writeObject(value.fieldNames()); + + if (value.properties() != null) { + gen.writeFieldName(INDEX_PROPERTIES); + gen.writeObject(value.properties()); + } + gen.writeEndObject(); } } @@ -1501,6 +1508,11 @@ public Index deserialize(JsonParser p, DeserializationContext ctxt) throws IOExc node.get(INDEX_FIELD_NAMES) .forEach(field -> fieldNames.add(getStringArray((ArrayNode) field))); builder.withFieldNames(fieldNames.toArray(new String[0][0])); + + if (node.has(INDEX_PROPERTIES)) { + builder.withProperties(getStringMapOrNull(INDEX_PROPERTIES, node)); + } + return builder.build(); } } diff --git a/common/src/test/java/org/apache/gravitino/json/TestSerializer.java b/common/src/test/java/org/apache/gravitino/json/TestSerializer.java index a8fff056225..46e40c5ed9f 100644 --- a/common/src/test/java/org/apache/gravitino/json/TestSerializer.java +++ b/common/src/test/java/org/apache/gravitino/json/TestSerializer.java @@ -133,10 +133,10 @@ void testIndexImplSerializer() throws JsonProcessingException { IndexType.UNIQUE_KEY, "index_2", new String[][] {new String[] {"col1"}, new String[] {"col2"}}); - actualJson = JsonUtils.anyFieldMapper().writeValueAsString(index); + actualJson = JsonUtils.anyFieldMapper().writeValueAsString(DTOConverters.toDTO(index)); expectedJson = """ - {"indexType":"unique_key","name":"index_2","fieldNames":[["col1"],["col2"]]}"""; + {"indexType":"UNIQUE_KEY","name":"index_2","fieldNames":[["col1"],["col2"]]}"""; Assertions.assertEquals(expectedJson, actualJson); } diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 7e0867dc713..bc917fcdfe8 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -47,8 +47,8 @@ dependencies { exclude(group = "com.fasterxml.jackson.core", module = "*") // provided by gravitino exclude(group = "com.fasterxml.jackson.datatype", module = "*") // provided by gravitino exclude(group = "commons-codec", module = "commons-codec") // provided by jcasbin - exclude(group = "com.google.guava", module = "guava") // provided by gravitino exclude(group = "org.apache.commons", module = "commons-lang3") // provided by gravitino + exclude(group = "com.google.guava", module = "guava") // provided by gravitino exclude(group = "org.junit.jupiter", module = "*") // provided by test scope } implementation(libs.mybatis) diff --git a/core/src/main/java/org/apache/gravitino/catalog/ManagedTableOperations.java b/core/src/main/java/org/apache/gravitino/catalog/ManagedTableOperations.java index 44060c6caef..178369268e4 100644 --- a/core/src/main/java/org/apache/gravitino/catalog/ManagedTableOperations.java +++ b/core/src/main/java/org/apache/gravitino/catalog/ManagedTableOperations.java @@ -248,6 +248,7 @@ private TableEntity applyChanges(TableEntity oldTableEntity, TableChange... chan .withName(addIndex.getName()) .withFieldNames(addIndex.getFieldNames()) .withIndexType(addIndex.getType()) + .withProperties(addIndex.getProperties()) .build(); newIndexes.add(newIndex); diff --git a/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java b/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java index 87cda80432b..fe3b29bceb7 100644 --- a/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java +++ b/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java @@ -555,7 +555,6 @@ public static TableEntity fromTableAndColumnPOs( .readValue(tablePO.getPartitions(), Partitioning[].class)) .withComment(tablePO.getComment()) .withProperties(properties) - .withColumns(fromColumnPOs(columnPOs)) .build(); } catch (JsonProcessingException e) { throw new RuntimeException("Failed to deserialize json object:", e); diff --git a/core/src/test/java/org/apache/gravitino/storage/relational/TestJDBCBackend.java b/core/src/test/java/org/apache/gravitino/storage/relational/TestJDBCBackend.java index cbe019d2f7a..6b54694d4ed 100644 --- a/core/src/test/java/org/apache/gravitino/storage/relational/TestJDBCBackend.java +++ b/core/src/test/java/org/apache/gravitino/storage/relational/TestJDBCBackend.java @@ -1248,13 +1248,13 @@ void testUpdateAndDropLanceTable() throws IOException, InterruptedException { .withName("table") .withAuditInfo(auditInfo) .withComment(null) - .withProperties(ImmutableMap.of("format", "lance", "location", "/tmp/test/lance")) + .withProperties(ImmutableMap.of("format", "LANCE", "location", "/tmp/test/lance")) .build(); backend.insert(table, false); TableEntity fetchedTable = backend.get(table.nameIdentifier(), Entity.EntityType.TABLE); - Assertions.assertEquals("lance", fetchedTable.properties().get("format")); + Assertions.assertEquals("LANCE", fetchedTable.properties().get("format")); TableEntity updatedTable = TableEntity.builder() diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4d650636eba..10c00970ce9 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -29,7 +29,7 @@ lombok = "1.18.20" slf4j = "2.0.16" log4j = "2.24.3" lance = "0.39.0" -lance-namespace = "0.0.19" +lance-namespace = "0.0.20" jetty = "9.4.51.v20230217" jersey = "2.41" mockito = "4.11.0" diff --git a/lance/lance-common/build.gradle.kts b/lance/lance-common/build.gradle.kts index e3a669212cb..d51d4aa9ee3 100644 --- a/lance/lance-common/build.gradle.kts +++ b/lance/lance-common/build.gradle.kts @@ -1,20 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ description = "lance-common" diff --git a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/LanceTableOperations.java b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/LanceTableOperations.java index 97b65d9bf0b..8a5c6c806ac 100644 --- a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/LanceTableOperations.java +++ b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/LanceTableOperations.java @@ -19,11 +19,17 @@ package org.apache.gravitino.lance.common.ops; import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexResponse; import com.lancedb.lance.namespace.model.CreateTableRequest; import com.lancedb.lance.namespace.model.CreateTableResponse; import com.lancedb.lance.namespace.model.DeregisterTableResponse; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsRequest; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsResponse; import com.lancedb.lance.namespace.model.DescribeTableResponse; import com.lancedb.lance.namespace.model.DropTableResponse; +import com.lancedb.lance.namespace.model.ListTableIndicesRequest; +import com.lancedb.lance.namespace.model.ListTableIndicesResponse; import com.lancedb.lance.namespace.model.RegisterTableRequest; import com.lancedb.lance.namespace.model.RegisterTableResponse; import java.util.Map; @@ -96,7 +102,7 @@ RegisterTableResponse registerTable( */ DeregisterTableResponse deregisterTable(String tableId, String delimiter); - /** + /* * Check if a table exists. * * @param tableId table ids are in the format of "{namespace}{delimiter}{table_name}" @@ -113,4 +119,38 @@ RegisterTableResponse registerTable( * @return the response of the drop table operation */ DropTableResponse dropTable(String tableId, String delimiter); + + /** + * Create an index on a Lance table. + * + * @param tableId table ids are in the format of "{namespace}{delimiter}{table_name}" + * @param delimiter the delimiter used in the namespace + * @param request the request containing index creation details + * @return the response of the create index operation. + */ + CreateTableIndexResponse createTableIndex( + String tableId, String delimiter, CreateTableIndexRequest request); + + /** + * List indices of a Lance table. + * + * @param tableId table ids are in the format of "{namespace}{delimiter}{table_name}" + * @param delimiter the delimiter used in the namespace + * @param request the request containing table id and other parameters + * @return the response containing the list of indices + */ + ListTableIndicesResponse listTableIndices( + String tableId, String delimiter, ListTableIndicesRequest request); + + /** + * Describe statistics of a specific index on a Lance table. + * + * @param tableId table ids are in the format of "{namespace}{delimiter}{table_name}" + * @param delimiter the delimiter used in the namespace + * @param indexId the identifier of the index to describe + * @param request the request containing table id and other parameters + * @return the response containing index statistics + */ + DescribeTableIndexStatsResponse describeTableIndexStats( + String tableId, String delimiter, String indexId, DescribeTableIndexStatsRequest request); } diff --git a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/gravitino/GravitinoLanceTableOperations.java b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/gravitino/GravitinoLanceTableOperations.java index 68add7a863d..5ebf7df60dc 100644 --- a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/gravitino/GravitinoLanceTableOperations.java +++ b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/ops/gravitino/GravitinoLanceTableOperations.java @@ -20,26 +20,36 @@ package org.apache.gravitino.lance.common.ops.gravitino; import static org.apache.gravitino.lance.common.ops.gravitino.LanceDataTypeConverter.CONVERTER; +import static org.apache.gravitino.lance.common.utils.LanceConstants.LANCE_INDEX_CONFIG_KEY; import static org.apache.gravitino.lance.common.utils.LanceConstants.LANCE_LOCATION; import static org.apache.gravitino.rel.Column.DEFAULT_VALUE_NOT_SET; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.lancedb.lance.namespace.LanceNamespaceException; import com.lancedb.lance.namespace.ObjectIdentifier; import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexResponse; import com.lancedb.lance.namespace.model.CreateTableRequest; import com.lancedb.lance.namespace.model.CreateTableRequest.ModeEnum; import com.lancedb.lance.namespace.model.CreateTableResponse; import com.lancedb.lance.namespace.model.DeregisterTableResponse; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsRequest; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsResponse; import com.lancedb.lance.namespace.model.DescribeTableResponse; import com.lancedb.lance.namespace.model.DropTableResponse; +import com.lancedb.lance.namespace.model.IndexContent; import com.lancedb.lance.namespace.model.JsonArrowSchema; +import com.lancedb.lance.namespace.model.ListTableIndicesRequest; +import com.lancedb.lance.namespace.model.ListTableIndicesResponse; import com.lancedb.lance.namespace.model.RegisterTableRequest; import com.lancedb.lance.namespace.model.RegisterTableResponse; import com.lancedb.lance.namespace.util.CommonUtil; import com.lancedb.lance.namespace.util.JsonArrowSchemaConverter; +import com.lancedb.lance.namespace.util.JsonUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -57,6 +67,9 @@ import org.apache.gravitino.lance.common.utils.LancePropertiesUtils; import org.apache.gravitino.rel.Column; import org.apache.gravitino.rel.Table; +import org.apache.gravitino.rel.TableChange; +import org.apache.gravitino.rel.TableChange.AddIndex; +import org.apache.gravitino.rel.indexes.Index.IndexType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -324,6 +337,78 @@ public DropTableResponse dropTable(String tableId, String delimiter) { return response; } + // Note: Create indices is an asynchronous operation in Lance Lakehouse. + @Override + public CreateTableIndexResponse createTableIndex( + String tableId, String delimiter, CreateTableIndexRequest request) { + ObjectIdentifier nsId = ObjectIdentifier.of(tableId, Pattern.quote(delimiter)); + Preconditions.checkArgument( + nsId.levels() == 3, "Expected at 3-level namespace but got: %s", nsId.levels()); + + String catalogName = nsId.levelAtListPos(0); + Catalog catalog = namespaceWrapper.loadAndValidateLakehouseCatalog(catalogName); + + NameIdentifier tableIdentifier = + NameIdentifier.of(nsId.levelAtListPos(1), nsId.levelAtListPos(2)); + + // There seem to be missing index name in the request, using Optional.empty() for now. + TableChange tableChange = buildAddIndex(Optional.empty(), request); + + Table table = catalog.asTableCatalog().alterTable(tableIdentifier, tableChange); + CreateTableIndexResponse response = new CreateTableIndexResponse(); + response.setId(nsId.listStyleId()); + response.setLocation(table.properties().get(LANCE_LOCATION)); + response.setProperties(table.properties()); + return response; + } + + @Override + public ListTableIndicesResponse listTableIndices( + String tableId, String delimiter, ListTableIndicesRequest request) { + ObjectIdentifier nsId = ObjectIdentifier.of(tableId, Pattern.quote(delimiter)); + Preconditions.checkArgument( + nsId.levels() == 3, "Expected at 3-level namespace but got: %s", nsId.levels()); + + String catalogName = nsId.levelAtListPos(0); + Catalog catalog = namespaceWrapper.loadAndValidateLakehouseCatalog(catalogName); + NameIdentifier tableIdentifier = + NameIdentifier.of(nsId.levelAtListPos(1), nsId.levelAtListPos(2)); + + Table table = catalog.asTableCatalog().loadTable(tableIdentifier); + ListTableIndicesResponse response = new ListTableIndicesResponse(); + List contents = + Arrays.stream(table.index()) + .map( + index -> { + IndexContent content = new IndexContent(); + List columnNames = new ArrayList<>(); + for (int i = 0; i < index.fieldNames().length; i++) { + columnNames.add(index.fieldNames()[i][0]); + } + content.setColumns(columnNames); + content.setIndexName(index.name()); + + // Currently there is no API to get index status, setting all indexes to READY for + // simplicity. So please note that this status may not reflect the actual index + // status. + content.setIndexUuid(index.name()); + content.setStatus("READY"); + return content; + }) + .collect(Collectors.toList()); + response.setIndexes(contents); + response.setPageToken(request.getPageToken()); + return response; + } + + @Override + public DescribeTableIndexStatsResponse describeTableIndexStats( + String tableId, String delimiter, String indexId, DescribeTableIndexStatsRequest request) { + // Do not support now as Lance dataset index creation is an asynchronous operation, and Lance + // dataset does not have index stats API now. + throw new UnsupportedOperationException("Describing table index stats is not supported now."); + } + private List extractColumns(org.apache.arrow.vector.types.pojo.Schema arrowSchema) { List columns = new ArrayList<>(); @@ -349,4 +434,18 @@ private JsonArrowSchema toJsonArrowSchema(Column[] columns) { return JsonArrowSchemaConverter.convertToJsonArrowSchema( new org.apache.arrow.vector.types.pojo.Schema(fields)); } + + private AddIndex buildAddIndex(Optional indexName, CreateTableIndexRequest request) { + try { + String requestJson = JsonUtil.mapper().writeValueAsString(request); + return new AddIndex( + IndexType.valueOf(request.getIndexType().name()), + indexName.orElse(null), + // It seems that only a single column index is supported for now. + new String[][] {{request.getColumn()}}, + ImmutableMap.of(LANCE_INDEX_CONFIG_KEY, requestJson)); + } catch (Exception e) { + throw new RuntimeException("Failed to build AddIndex from CreateTableIndexRequest", e); + } + } } diff --git a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/utils/LanceConstants.java b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/utils/LanceConstants.java index c34a7be58a2..865e943abd4 100644 --- a/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/utils/LanceConstants.java +++ b/lance/lance-common/src/main/java/org/apache/gravitino/lance/common/utils/LanceConstants.java @@ -31,4 +31,6 @@ public class LanceConstants { // Prefix for storage options in LanceConfig public static final String LANCE_STORAGE_OPTIONS_PREFIX = "lance.storage."; + + public static final String LANCE_INDEX_CONFIG_KEY = "lance_index_config"; } diff --git a/lance/lance-rest-server/build.gradle.kts b/lance/lance-rest-server/build.gradle.kts index a74fe4819d1..fac788213fc 100644 --- a/lance/lance-rest-server/build.gradle.kts +++ b/lance/lance-rest-server/build.gradle.kts @@ -37,7 +37,10 @@ dependencies { } implementation(project(":lance:lance-common")) - implementation(libs.lance) + implementation(libs.lance) { + exclude(group = "org.apache.commons", module = "commons-lang3") // provided by gravitino + exclude(group = "com.google.guava", module = "guava") // provided by gravitino + } implementation(libs.commons.lang3) implementation(libs.bundles.jetty) diff --git a/lance/lance-rest-server/src/main/java/org/apache/gravitino/lance/service/rest/LanceTableOperations.java b/lance/lance-rest-server/src/main/java/org/apache/gravitino/lance/service/rest/LanceTableOperations.java index ab46353c832..b18f6497772 100644 --- a/lance/lance-rest-server/src/main/java/org/apache/gravitino/lance/service/rest/LanceTableOperations.java +++ b/lance/lance-rest-server/src/main/java/org/apache/gravitino/lance/service/rest/LanceTableOperations.java @@ -28,14 +28,20 @@ import com.google.common.collect.Maps; import com.lancedb.lance.namespace.model.CreateEmptyTableRequest; import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexResponse; import com.lancedb.lance.namespace.model.CreateTableRequest; import com.lancedb.lance.namespace.model.CreateTableResponse; import com.lancedb.lance.namespace.model.DeregisterTableRequest; import com.lancedb.lance.namespace.model.DeregisterTableResponse; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsRequest; +import com.lancedb.lance.namespace.model.DescribeTableIndexStatsResponse; import com.lancedb.lance.namespace.model.DescribeTableRequest; import com.lancedb.lance.namespace.model.DescribeTableResponse; import com.lancedb.lance.namespace.model.DropTableRequest; import com.lancedb.lance.namespace.model.DropTableResponse; +import com.lancedb.lance.namespace.model.ListTableIndicesRequest; +import com.lancedb.lance.namespace.model.ListTableIndicesResponse; import com.lancedb.lance.namespace.model.RegisterTableRequest; import com.lancedb.lance.namespace.model.RegisterTableRequest.ModeEnum; import com.lancedb.lance.namespace.model.RegisterTableResponse; @@ -55,6 +61,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.Response; +import org.apache.arrow.util.Preconditions; import org.apache.gravitino.lance.common.ops.NamespaceWrapper; import org.apache.gravitino.lance.common.utils.SerializationUtils; import org.apache.gravitino.lance.service.LanceExceptionMapper; @@ -154,7 +161,7 @@ public Response createEmptyTable( @ResponseMetered(name = "register-table", absolute = true) public Response registerTable( @PathParam("id") String tableId, - @QueryParam("delimiter") @DefaultValue("$") String delimiter, + @QueryParam("delimiter") @DefaultValue(NAMESPACE_DELIMITER_DEFAULT) String delimiter, @Context HttpHeaders headers, RegisterTableRequest registerTableRequest) { try { @@ -182,7 +189,7 @@ public Response registerTable( @ResponseMetered(name = "deregister-table", absolute = true) public Response deregisterTable( @PathParam("id") String tableId, - @QueryParam("delimiter") @DefaultValue("$") String delimiter, + @QueryParam("delimiter") @DefaultValue(NAMESPACE_DELIMITER_DEFAULT) String delimiter, @Context HttpHeaders headers, DeregisterTableRequest deregisterTableRequest) { try { @@ -195,6 +202,25 @@ public Response deregisterTable( } } + @POST + @Path("/create_index") + @Timed(name = "create-table-index." + MetricNames.HTTP_PROCESS_DURATION, absolute = true) + @ResponseMetered(name = "create-table-index", absolute = true) + public Response createTableIndex( + @PathParam("id") String tableId, + @QueryParam("delimiter") @DefaultValue(NAMESPACE_DELIMITER_DEFAULT) String delimiter, + @Context HttpHeaders headers, + CreateTableIndexRequest createTableIndexRequest) { + try { + validateCreateTableIndexRequest(createTableIndexRequest); + CreateTableIndexResponse response = + lanceNamespace.asTableOps().createTableIndex(tableId, delimiter, createTableIndexRequest); + return Response.ok(response).build(); + } catch (Exception e) { + return LanceExceptionMapper.toRESTResponse(tableId, e); + } + } + @POST @Path("/exists") @Timed(name = "table-exists." + MetricNames.HTTP_PROCESS_DURATION, absolute = true) @@ -235,13 +261,57 @@ public Response dropTable( } } + @POST + @Path("/index/list") + @Timed(name = "list-table-indices." + MetricNames.HTTP_PROCESS_DURATION, absolute = true) + @ResponseMetered(name = "list-table-indices", absolute = true) + public Response listTableIndices( + @PathParam("id") String tableId, + @QueryParam("delimiter") @DefaultValue(NAMESPACE_DELIMITER_DEFAULT) String delimiter, + @Context HttpHeaders headers, + ListTableIndicesRequest listTableIndicesRequest) { + try { + validateListTableIndicesRequest(listTableIndicesRequest); + ListTableIndicesResponse response = + lanceNamespace.asTableOps().listTableIndices(tableId, delimiter, listTableIndicesRequest); + return Response.ok(response).build(); + } catch (Exception e) { + return LanceExceptionMapper.toRESTResponse(tableId, e); + } + } + + @POST + @Path("/index/{index_name}/stats") + @Timed(name = "describe-table-index." + MetricNames.HTTP_PROCESS_DURATION, absolute = true) + @ResponseMetered(name = "describe-table-index", absolute = true) + public Response describeTableIndex( + @PathParam("id") String tableId, + @PathParam("index_name") String indexName, + @QueryParam("delimiter") @DefaultValue(NAMESPACE_DELIMITER_DEFAULT) String delimiter, + @Context HttpHeaders headers, + DescribeTableIndexStatsRequest describeTableIndexStatsRequest) { + try { + validateDescribeTableIndexRequest(describeTableIndexStatsRequest); + DescribeTableIndexStatsResponse response = + lanceNamespace + .asTableOps() + .describeTableIndexStats( + tableId, delimiter, indexName, describeTableIndexStatsRequest); + return Response.ok(response).build(); + } catch (Exception e) { + return LanceExceptionMapper.toRESTResponse(tableId, e); + } + } + private void validateCreateEmptyTableRequest( @SuppressWarnings("unused") CreateEmptyTableRequest request) { + // We will ignore the id in the request body since it's already provided in the path param. // No specific fields to validate for now } private void validateRegisterTableRequest( @SuppressWarnings("unused") RegisterTableRequest request) { + // We will ignore the id in the request body since it's already provided in the path param. // No specific fields to validate for now } @@ -257,11 +327,28 @@ private void validateDescribeTableRequest( // No specific fields to validate for now } + private void validateCreateTableIndexRequest( + @SuppressWarnings("unused") CreateTableIndexRequest request) { + Preconditions.checkArgument(request != null, "CreateTableIndexRequest must not be null"); + } + + private void validateListTableIndicesRequest( + @SuppressWarnings("unused") ListTableIndicesRequest request) { + // We will ignore the id in the request body since it's already provided in the path param + // No specific fields to validate for now + } + private void validateTableExists(@SuppressWarnings("unused") TableExistsRequest request) { // We will ignore the id in the request body since it's already provided in the path param // No specific fields to validate for now } + private void validateDescribeTableIndexRequest( + @SuppressWarnings("unused") DescribeTableIndexStatsRequest request) { + // We will ignore the id in the request body since it's already provided in the path param + // No specific fields to validate for now + } + private void validateDropTableRequest(@SuppressWarnings("unused") DropTableRequest request) { // We will ignore the id in the request body since it's already provided in the path param // No specific fields to validate for now diff --git a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/integration/test/LanceRESTServiceIT.java b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/integration/test/LanceRESTServiceIT.java index 028d98215e8..4933af4870e 100644 --- a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/integration/test/LanceRESTServiceIT.java +++ b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/integration/test/LanceRESTServiceIT.java @@ -18,9 +18,19 @@ */ package org.apache.gravitino.lance.integration.test; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; +import com.lancedb.lance.Dataset; +import com.lancedb.lance.Fragment; +import com.lancedb.lance.FragmentMetadata; +import com.lancedb.lance.Transaction; +import com.lancedb.lance.WriteParams; +import com.lancedb.lance.ipc.LanceScanner; +import com.lancedb.lance.ipc.ScanOptions; import com.lancedb.lance.namespace.LanceNamespace; import com.lancedb.lance.namespace.LanceNamespaceException; import com.lancedb.lance.namespace.LanceNamespaces; @@ -29,6 +39,10 @@ import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; import com.lancedb.lance.namespace.model.CreateNamespaceRequest; import com.lancedb.lance.namespace.model.CreateNamespaceResponse; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest.IndexTypeEnum; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest.MetricTypeEnum; +import com.lancedb.lance.namespace.model.CreateTableIndexResponse; import com.lancedb.lance.namespace.model.CreateTableRequest; import com.lancedb.lance.namespace.model.CreateTableResponse; import com.lancedb.lance.namespace.model.DeregisterTableRequest; @@ -42,9 +56,12 @@ import com.lancedb.lance.namespace.model.DropTableRequest; import com.lancedb.lance.namespace.model.DropTableResponse; import com.lancedb.lance.namespace.model.ErrorResponse; +import com.lancedb.lance.namespace.model.IndexContent; import com.lancedb.lance.namespace.model.JsonArrowField; import com.lancedb.lance.namespace.model.ListNamespacesRequest; import com.lancedb.lance.namespace.model.ListNamespacesResponse; +import com.lancedb.lance.namespace.model.ListTableIndicesRequest; +import com.lancedb.lance.namespace.model.ListTableIndicesResponse; import com.lancedb.lance.namespace.model.ListTablesRequest; import com.lancedb.lance.namespace.model.NamespaceExistsRequest; import com.lancedb.lance.namespace.model.RegisterTableRequest; @@ -52,9 +69,11 @@ import com.lancedb.lance.namespace.model.RegisterTableResponse; import com.lancedb.lance.namespace.model.TableExistsRequest; import com.lancedb.lance.namespace.rest.RestNamespaceConfig; +import com.lancedb.lance.operation.Append; import java.io.File; import java.io.IOException; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; @@ -63,11 +82,21 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Random; import java.util.Set; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.commons.io.FileUtils; import org.apache.gravitino.Catalog; import org.apache.gravitino.NameIdentifier; @@ -762,6 +791,240 @@ void testDropTable() { Assertions.assertEquals(404, exception.getCode()); } + @Test + void testCreateTableIndex() throws IOException { + catalog = createCatalog(CATALOG_NAME); + createSchema(); + List ids = List.of(CATALOG_NAME, SCHEMA_NAME, "non_existing_table"); + + // We need to create a table first; + org.apache.arrow.vector.types.pojo.Schema schema = + new org.apache.arrow.vector.types.pojo.Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("value", new ArrowType.Utf8()), + new Field( + "vector", + FieldType.nullable(new ArrowType.FixedSizeList(4)), + ImmutableList.of( + Field.nullable( + "fake", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)))))); + byte[] body = ArrowUtils.generateIpcStream(schema); + + CreateTableRequest request = new CreateTableRequest(); + request.setId(ids); + request.setLocation(tempDir + "/" + "table_for_index/"); + request.setProperties( + ImmutableMap.of( + "key1", "v1", + "lance.storage.a", "value_a", + "lance.storage.c", "value_c")); + + CreateTableResponse response = ns.createTable(request, body); + Assertions.assertEquals(request.getLocation(), response.getLocation()); + + writeDataToLance(request.getLocation()); + + // Now try to create Btree index on an existing table + CreateTableIndexRequest createTableIndexRequest = new CreateTableIndexRequest(); + createTableIndexRequest.setId(ids); + createTableIndexRequest.setIndexType(IndexTypeEnum.BTREE); + createTableIndexRequest.setColumn("id"); + createTableIndexRequest.setMetricType(MetricTypeEnum.L2); + CreateTableIndexResponse createTableIndexResponse = + Assertions.assertDoesNotThrow(() -> ns.createTableIndex(createTableIndexRequest)); + Assertions.assertNotNull(createTableIndexResponse); + + // Now try to create bitmap index on an existing table + createTableIndexRequest.setIndexType(IndexTypeEnum.BITMAP); + createTableIndexRequest.setColumn("value"); + createTableIndexResponse = + Assertions.assertDoesNotThrow(() -> ns.createTableIndex(createTableIndexRequest)); + Assertions.assertNotNull(createTableIndexResponse); + List indices = listIndices(request.getLocation()); + Assertions.assertEquals(2, indices.size()); + // Now try to create vector index on an existing table + createTableIndexRequest.setIndexType(IndexTypeEnum.IVF_FLAT); + createTableIndexRequest.setColumn("vector"); + createTableIndexResponse = + Assertions.assertDoesNotThrow(() -> ns.createTableIndex(createTableIndexRequest)); + Assertions.assertNotNull(createTableIndexResponse); + + ListTableIndicesRequest listTableIndicesRequest = new ListTableIndicesRequest(); + listTableIndicesRequest.setId(ids); + ListTableIndicesResponse listTableIndicesResponse = + ns.listTableIndices(listTableIndicesRequest); + Assertions.assertEquals(3, listTableIndicesResponse.getIndexes().size()); + List expectedIndexName = listIndices(request.getLocation()); + for (IndexContent indexContent : listTableIndicesResponse.getIndexes()) { + Assertions.assertTrue( + expectedIndexName.contains(indexContent.getIndexName()), + "Index name should be in the expected index names."); + if (indexContent.getIndexName().equals("id_idx")) { + Assertions.assertEquals("id", indexContent.getColumns().get(0)); + } else if (indexContent.getIndexName().equals("value_idx")) { + Assertions.assertEquals("value", indexContent.getColumns().get(0)); + } else if (indexContent.getIndexName().equals("vector_idx")) { + Assertions.assertEquals("vector", indexContent.getColumns().get(0)); + } + } + + // create another table to test other index types + ids = List.of(CATALOG_NAME, SCHEMA_NAME, "table_for_other_indexes"); + request.setId(ids); + request.setLocation(tempDir + "/" + "table_for_other_indexes/"); + response = ns.createTable(request, body); + Assertions.assertEquals(request.getLocation(), response.getLocation()); + writeDataToLance(request.getLocation()); + + // Now try to create FTS index on an existing table + createTableIndexRequest.setId(ids); + createTableIndexRequest.setIndexType(IndexTypeEnum.FTS); + createTableIndexRequest.setColumn("value"); + + LanceNamespaceException exception = + Assertions.assertThrows( + LanceNamespaceException.class, () -> ns.createTableIndex(createTableIndexRequest)); + // com.lancedb.lance.index.IndexType does not have FTS yet, so it should throw exception + Assertions.assertTrue( + exception.getMessage().contains("No enum constant com.lancedb.lance.index.IndexType.FTS")); + } + + private List listIndices(String lanceTableLocation) { + try (Dataset dataset = Dataset.open(lanceTableLocation)) { + return dataset.listIndexes(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void writeDataToLance(String tableLocation) { + try (Dataset dataset = Dataset.open(tableLocation)) { + org.apache.arrow.vector.types.pojo.Schema lanceSchema = dataset.getSchema(); + Transaction trans = + dataset + .newTransactionBuilder() + .operation( + Append.builder() + .fragments( + createFragmentMetadata(tableLocation, generateLanceData(), lanceSchema)) + .build()) + .writeParams(ImmutableMap.of()) + .build(); + + Dataset newDataset = dataset.commitTransaction(trans); + + try (LanceScanner scanner = + newDataset.newScan( + new ScanOptions.Builder() + .columns(Arrays.asList("id", "value", "vector")) + .batchSize(1000) + .build())) { + + List dataValues = com.google.common.collect.Lists.newArrayList(); + try (ArrowReader reader = scanner.scanBatches()) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + List fieldVectors = root.getFieldVectors(); + + IntVector ids = (IntVector) fieldVectors.get(0); + VarCharVector values = (VarCharVector) fieldVectors.get(1); + FixedSizeListVector vectors = (FixedSizeListVector) fieldVectors.get(2); + + for (int i = 0; i < root.getRowCount(); i++) { + int id = ids.get(i); + String value = new String(values.get(i), StandardCharsets.UTF_8); + List vector = com.google.common.collect.Lists.newArrayList(); + for (int j = 0; j < 4; j++) { + Float floatValue = ((Float4Vector) vectors.getDataVector()).get(i * 4 + j); + vector.add(floatValue); + } + + dataValues.add(new LanceDataValue(id, value, vector)); + } + } + } + + Assertions.assertEquals(5120, dataValues.size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private List generateLanceData() { + List updates = Lists.newArrayList(); + Random random = new Random(); + for (int i = 0; i < 5120; i++) { + LanceDataValue data = + new LanceDataValue( + i, + "value_" + i, + Arrays.asList( + (float) random.nextInt(10000), + (float) random.nextInt(10000), + (float) random.nextInt(10000), + (float) random.nextInt(10000))); + updates.add(data); + } + + return updates; + } + + private List createFragmentMetadata( + String tableLocation, + List updates, + org.apache.arrow.vector.types.pojo.Schema schema) + throws JsonProcessingException { + List fragmentMetas; + int count = 0; + RootAllocator rootAllocator = new RootAllocator(); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, rootAllocator)) { + for (FieldVector vector : root.getFieldVectors()) { + vector.setInitialCapacity(count); + } + root.allocateNew(); + + IntVector ids = (IntVector) root.getVector("id"); + VarCharVector values = (VarCharVector) root.getVector("value"); + FixedSizeListVector vectors = (FixedSizeListVector) root.getVector("vector"); + vectors.allocateNew(); + Float4Vector dataVector = (Float4Vector) vectors.getDataVector(); + + int index = 0; + for (LanceDataValue data : updates) { + ids.setSafe(index, data.id); + values.setSafe(index, data.value.getBytes(StandardCharsets.UTF_8)); + vectors.setNotNull(index); + for (int i = 0; i < 4; i++) { + Float floatValue = data.vector.get(i); + dataVector.setSafe(index * 4 + i, floatValue); + } + index++; + } + root.setRowCount(index); + + fragmentMetas = + Fragment.create(tableLocation, rootAllocator, root, new WriteParams.Builder().build()); + return fragmentMetas; + } + } + + static class LanceDataValue { + + public Integer id; + public String value; + public List vector; + + public LanceDataValue(Integer id, String value, List vector) { + this.id = id; + this.value = value; + this.vector = vector; + } + } + private GravitinoMetalake createMetalake(String metalakeName) { return client.createMetalake(metalakeName, "metalake for lance rest service tests", null); } diff --git a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/MockServletRequestFactory.java b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/MockServletRequestFactory.java new file mode 100644 index 00000000000..de85b4ce57c --- /dev/null +++ b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/MockServletRequestFactory.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.gravitino.lance.service.rest; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import javax.servlet.http.HttpServletRequest; + +public class MockServletRequestFactory extends ServletRequestFactoryBase { + @Override + public HttpServletRequest get() { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteUser()).thenReturn(null); + return request; + } +} diff --git a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceNamespaceOperations.java b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceNamespaceOperations.java index 02d5a6e812d..a887de3561d 100644 --- a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceNamespaceOperations.java +++ b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceNamespaceOperations.java @@ -21,32 +21,19 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.lancedb.lance.namespace.LanceNamespaceException; -import com.lancedb.lance.namespace.model.CreateEmptyTableRequest; -import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; import com.lancedb.lance.namespace.model.CreateNamespaceRequest; import com.lancedb.lance.namespace.model.CreateNamespaceResponse; -import com.lancedb.lance.namespace.model.CreateTableResponse; -import com.lancedb.lance.namespace.model.DeregisterTableRequest; -import com.lancedb.lance.namespace.model.DeregisterTableResponse; import com.lancedb.lance.namespace.model.DescribeNamespaceResponse; -import com.lancedb.lance.namespace.model.DescribeTableRequest; -import com.lancedb.lance.namespace.model.DescribeTableResponse; import com.lancedb.lance.namespace.model.DropNamespaceRequest; import com.lancedb.lance.namespace.model.DropNamespaceResponse; -import com.lancedb.lance.namespace.model.DropTableResponse; import com.lancedb.lance.namespace.model.ErrorResponse; import com.lancedb.lance.namespace.model.ListNamespacesResponse; -import com.lancedb.lance.namespace.model.RegisterTableRequest; -import com.lancedb.lance.namespace.model.RegisterTableResponse; import java.io.IOException; import java.util.regex.Pattern; import javax.servlet.http.HttpServletRequest; @@ -55,7 +42,6 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.apache.gravitino.exceptions.NoSuchCatalogException; -import org.apache.gravitino.lance.common.ops.LanceTableOperations; import org.apache.gravitino.lance.common.ops.NamespaceWrapper; import org.apache.gravitino.rest.RESTUtils; import org.glassfish.jersey.internal.inject.AbstractBinder; @@ -68,19 +54,10 @@ import org.mockito.Mockito; public class TestLanceNamespaceOperations extends JerseyTest { - private static class MockServletRequestFactory extends ServletRequestFactoryBase { - @Override - public HttpServletRequest get() { - HttpServletRequest request = mock(HttpServletRequest.class); - when(request.getRemoteUser()).thenReturn(null); - return request; - } - } private static NamespaceWrapper namespaceWrapper = mock(NamespaceWrapper.class); private static org.apache.gravitino.lance.common.ops.LanceNamespaceOperations namespaceOps = mock(org.apache.gravitino.lance.common.ops.LanceNamespaceOperations.class); - private static LanceTableOperations tableOps = mock(LanceTableOperations.class); @Override protected Application configure() { @@ -93,7 +70,6 @@ protected Application configure() { ResourceConfig resourceConfig = new ResourceConfig(); resourceConfig.register(LanceNamespaceOperations.class); - resourceConfig.register(org.apache.gravitino.lance.service.rest.LanceTableOperations.class); resourceConfig.register( new AbstractBinder() { @Override @@ -109,7 +85,6 @@ protected void configure() { @BeforeAll public static void setup() { when(namespaceWrapper.asNamespaceOps()).thenReturn(namespaceOps); - when(namespaceWrapper.asTableOps()).thenReturn(tableOps); } @Test @@ -340,436 +315,4 @@ public void testDropNamespace() { Assertions.assertEquals("Test exception", errorResp.getError()); Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); } - - @Test - void testCreateTable() { - String tableIds = "catalog.scheme.create_table"; - String delimiter = "."; - - // Test normal - CreateTableResponse createTableResponse = new CreateTableResponse(); - when(tableOps.createTable(any(), any(), any(), any(), any(), any())) - .thenReturn(createTableResponse); - - byte[] bytes = new byte[] {0x01, 0x02, 0x03}; - Response resp = - target(String.format("/v1/table/%s/create", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - // Test illegal argument - when(tableOps.createTable(any(), any(), any(), any(), any(), any())) - .thenThrow(new IllegalArgumentException("Illegal argument")); - - resp = - target(String.format("/v1/table/%s/create", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); - Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - // Test runtime exception - Mockito.reset(tableOps); - when(tableOps.createTable(any(), any(), any(), any(), any(), any())) - .thenThrow(new RuntimeException("Runtime exception")); - resp = - target(String.format("/v1/table/%s/create", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); - - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals("Runtime exception", errorResp.getError()); - Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); - } - - @Test - void testCreateEmptyTable() { - String tableIds = "catalog.scheme.create_empty_table"; - String delimiter = "."; - - // Test normal - CreateEmptyTableResponse createTableResponse = new CreateEmptyTableResponse(); - createTableResponse.setLocation("/path/to/table"); - createTableResponse.setProperties(ImmutableMap.of("key", "value")); - when(tableOps.createEmptyTable(any(), any(), any(), any())).thenReturn(createTableResponse); - - CreateEmptyTableRequest tableRequest = new CreateEmptyTableRequest(); - tableRequest.setLocation("/path/to/table"); - - Response resp = - target(String.format("/v1/table/%s/create-empty", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - CreateEmptyTableResponse response = resp.readEntity(CreateEmptyTableResponse.class); - Assertions.assertEquals(createTableResponse.getLocation(), response.getLocation()); - Assertions.assertEquals(createTableResponse.getProperties(), response.getProperties()); - - Mockito.reset(tableOps); - // Test illegal argument - when(tableOps.createEmptyTable(any(), any(), any(), any())) - .thenThrow(new IllegalArgumentException("Illegal argument")); - - resp = - target(String.format("/v1/table/%s/create-empty", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - // Test runtime exception - Mockito.reset(tableOps); - when(tableOps.createEmptyTable(any(), any(), any(), any())) - .thenThrow(new RuntimeException("Runtime exception")); - resp = - target(String.format("/v1/table/%s/create-empty", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals("Runtime exception", errorResp.getError()); - Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); - } - - @Test - void testRegisterTable() { - String tableIds = "catalog.scheme.register_table"; - String delimiter = "."; - - // Test normal - RegisterTableResponse registerTableResponse = new RegisterTableResponse(); - registerTableResponse.setLocation("/path/to/registered_table"); - registerTableResponse.setProperties(ImmutableMap.of("key", "value")); - when(tableOps.registerTable(any(), any(), any(), any())).thenReturn(registerTableResponse); - - RegisterTableRequest tableRequest = new RegisterTableRequest(); - tableRequest.setLocation("/path/to/registered_table"); - tableRequest.setMode(RegisterTableRequest.ModeEnum.CREATE); - - Response resp = - target(String.format("/v1/table/%s/register", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - RegisterTableResponse response = resp.readEntity(RegisterTableResponse.class); - Assertions.assertEquals(registerTableResponse.getLocation(), response.getLocation()); - Assertions.assertEquals(registerTableResponse.getProperties(), response.getProperties()); - - // Test illegal argument - Mockito.reset(tableOps); - when(tableOps.registerTable(any(), any(), any(), any())) - .thenThrow(new IllegalArgumentException("Illegal argument")); - resp = - target(String.format("/v1/table/%s/register", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - // Test runtime exception - Mockito.reset(tableOps); - when(tableOps.registerTable(any(), any(), any(), any())) - .thenThrow(new RuntimeException("Runtime exception")); - resp = - target(String.format("/v1/table/%s/register", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals("Runtime exception", errorResp.getError()); - Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); - } - - @Test - void testRegisterTableSetsRegisterPropertyToTrue() { - String tableIds = "catalog.scheme.register_table_with_property"; - String delimiter = "."; - - // Reset mock to clear any previous test state - Mockito.reset(tableOps); - - // Test that the "register" property is set to "true" - RegisterTableResponse registerTableResponse = new RegisterTableResponse(); - registerTableResponse.setLocation("/path/to/registered_table"); - registerTableResponse.setProperties(ImmutableMap.of("key", "value", "register", "true")); - when(tableOps.registerTable(any(), any(), any(), any())).thenReturn(registerTableResponse); - - RegisterTableRequest tableRequest = new RegisterTableRequest(); - tableRequest.setLocation("/path/to/registered_table"); - tableRequest.setMode(RegisterTableRequest.ModeEnum.CREATE); - tableRequest.setProperties(ImmutableMap.of("custom-key", "custom-value")); - - Response resp = - target(String.format("/v1/table/%s/register", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - - // Verify that registerTable was called with properties containing "register": "true" - Mockito.verify(tableOps) - .registerTable( - eq(tableIds), - eq(RegisterTableRequest.ModeEnum.CREATE), - eq(delimiter), - Mockito.argThat( - props -> - props != null - && "true".equals(props.get("register")) - && "/path/to/registered_table".equals(props.get("location")) - && "custom-value".equals(props.get("custom-key")))); - } - - @Test - void testDeregisterTable() { - String tableIds = "catalog.scheme.deregister_table"; - String delimiter = "."; - - DeregisterTableRequest tableRequest = new DeregisterTableRequest(); - - DeregisterTableResponse deregisterTableResponse = new DeregisterTableResponse(); - deregisterTableResponse.setLocation("/path/to/deregistered_table"); - deregisterTableResponse.setProperties(ImmutableMap.of("key", "value")); - // Test normal - when(tableOps.deregisterTable(any(), any())).thenReturn(deregisterTableResponse); - - Response resp = - target(String.format("/v1/table/%s/deregister", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - DeregisterTableResponse response = resp.readEntity(DeregisterTableResponse.class); - Assertions.assertEquals(deregisterTableResponse.getLocation(), response.getLocation()); - Assertions.assertEquals(deregisterTableResponse.getProperties(), response.getProperties()); - - // Test illegal argument - Mockito.reset(tableOps); - when(tableOps.deregisterTable(any(), any())) - .thenThrow(new IllegalArgumentException("Illegal argument")); - resp = - target(String.format("/v1/table/%s/deregister", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - // Test not found exception - Mockito.reset(tableOps); - when(tableOps.deregisterTable(any(), any())) - .thenThrow( - LanceNamespaceException.notFound( - "Table not found", "NoSuchTableException", tableIds, "")); - resp = - target(String.format("/v1/table/%s/deregister", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); - - // Test runtime exception - Mockito.reset(tableOps); - when(tableOps.deregisterTable(any(), any())) - .thenThrow(new RuntimeException("Runtime exception")); - resp = - target(String.format("/v1/table/%s/deregister", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals("Runtime exception", errorResp.getError()); - Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); - } - - @Test - void testDescribeTable() { - String tableIds = "catalog.scheme.describe_table"; - String delimiter = "."; - - // Test normal - DescribeTableResponse createTableResponse = new DescribeTableResponse(); - createTableResponse.setLocation("/path/to/describe_table"); - createTableResponse.setProperties(ImmutableMap.of("key", "value")); - when(tableOps.describeTable(any(), any(), any())).thenReturn(createTableResponse); - - DescribeTableRequest tableRequest = new DescribeTableRequest(); - Response resp = - target(String.format("/v1/table/%s/describe", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - DescribeTableResponse response = resp.readEntity(DescribeTableResponse.class); - Assertions.assertEquals(createTableResponse.getLocation(), response.getLocation()); - Assertions.assertEquals(createTableResponse.getProperties(), response.getProperties()); - - // Test not found exception - Mockito.reset(tableOps); - when(tableOps.describeTable(any(), any(), any())) - .thenThrow( - LanceNamespaceException.notFound( - "Table not found", "NoSuchTableException", tableIds, "")); - resp = - target(String.format("/v1/table/%s/describe", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); - - // Test runtime exception - Mockito.reset(tableOps); - when(tableOps.describeTable(any(), any(), any())) - .thenThrow(new RuntimeException("Runtime exception")); - resp = - target(String.format("/v1/table/%s/describe", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); - - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals("Runtime exception", errorResp.getError()); - Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); - } - - @Test - void testTableExists() { - String tableIds = "catalog.scheme.table_exists"; - String delimiter = "."; - - doReturn(true).when(tableOps).tableExists(any(), any()); - - Response resp = - target(String.format("/v1/table/%s/exists", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - - // test throw exception - doThrow( - LanceNamespaceException.notFound( - "Table not found", "NoSuchTableException", tableIds, "")) - .when(tableOps) - .tableExists(any(), any()); - resp = - target(String.format("/v1/table/%s/exists", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - - Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals(404, errorResp.getCode()); - Assertions.assertEquals("Table not found", errorResp.getError()); - Assertions.assertEquals("NoSuchTableException", errorResp.getType()); - - // Test runtime exception - Mockito.reset(tableOps); - doThrow(new RuntimeException("Runtime exception")).when(tableOps).tableExists(any(), any()); - resp = - target(String.format("/v1/table/%s/exists", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - } - - @Test - void testDropTable() { - String tableIds = "catalog.scheme.drop_table"; - String delimiter = "."; - - DropTableResponse dropTableResponse = new DropTableResponse(); - dropTableResponse.setId(Lists.newArrayList("catalog", "scheme", "drop_table")); - dropTableResponse.setProperties(ImmutableMap.of("key", "value")); - dropTableResponse.setLocation("/path/to/drop_table"); - Mockito.doReturn(dropTableResponse).when(tableOps).dropTable(any(), any()); - - Response resp = - target(String.format("/v1/table/%s/drop", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - - Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); - DropTableResponse response = resp.readEntity(DropTableResponse.class); - Assertions.assertEquals(dropTableResponse.getId(), response.getId()); - Assertions.assertEquals(dropTableResponse.getProperties(), response.getProperties()); - Assertions.assertEquals(dropTableResponse.getLocation(), response.getLocation()); - - // test throw exception - doThrow( - LanceNamespaceException.notFound( - "Table not found", "NoSuchTableException", tableIds, "")) - .when(tableOps) - .dropTable(any(), any()); - resp = - target(String.format("/v1/table/%s/drop", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - - Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - - ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); - Assertions.assertEquals(404, errorResp.getCode()); - Assertions.assertEquals("Table not found", errorResp.getError()); - Assertions.assertEquals("NoSuchTableException", errorResp.getType()); - - // Test runtime exception - Mockito.reset(tableOps); - doThrow(new RuntimeException("Runtime exception")).when(tableOps).dropTable(any(), any()); - resp = - target(String.format("/v1/table/%s/drop", tableIds)) - .queryParam("delimiter", delimiter) - .request(MediaType.APPLICATION_JSON_TYPE) - .post(null); - Assertions.assertEquals( - Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); - Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); - } } diff --git a/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceTableOperations.java b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceTableOperations.java new file mode 100644 index 00000000000..34d1b290e3e --- /dev/null +++ b/lance/lance-rest-server/src/test/java/org/apache/gravitino/lance/service/rest/TestLanceTableOperations.java @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.gravitino.lance.service.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import com.lancedb.lance.namespace.LanceNamespaceException; +import com.lancedb.lance.namespace.model.CreateEmptyTableRequest; +import com.lancedb.lance.namespace.model.CreateEmptyTableResponse; +import com.lancedb.lance.namespace.model.CreateTableIndexRequest; +import com.lancedb.lance.namespace.model.CreateTableIndexResponse; +import com.lancedb.lance.namespace.model.CreateTableResponse; +import com.lancedb.lance.namespace.model.DeregisterTableRequest; +import com.lancedb.lance.namespace.model.DeregisterTableResponse; +import com.lancedb.lance.namespace.model.DescribeTableRequest; +import com.lancedb.lance.namespace.model.DescribeTableResponse; +import com.lancedb.lance.namespace.model.ErrorResponse; +import com.lancedb.lance.namespace.model.IndexContent; +import com.lancedb.lance.namespace.model.ListTableIndicesRequest; +import com.lancedb.lance.namespace.model.ListTableIndicesResponse; +import com.lancedb.lance.namespace.model.RegisterTableRequest; +import com.lancedb.lance.namespace.model.RegisterTableResponse; +import java.io.IOException; +import java.util.List; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.Application; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.apache.gravitino.lance.common.ops.LanceTableOperations; +import org.apache.gravitino.lance.common.ops.NamespaceWrapper; +import org.apache.gravitino.rest.RESTUtils; +import org.glassfish.jersey.internal.inject.AbstractBinder; +import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.test.JerseyTest; +import org.glassfish.jersey.test.TestProperties; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class TestLanceTableOperations extends JerseyTest { + + private static NamespaceWrapper namespaceWrapper = mock(NamespaceWrapper.class); + private static org.apache.gravitino.lance.common.ops.LanceTableOperations tableOps = + mock(LanceTableOperations.class); + + @Override + protected Application configure() { + try { + forceSet( + TestProperties.CONTAINER_PORT, String.valueOf(RESTUtils.findAvailablePort(2000, 3000))); + } catch (IOException e) { + throw new RuntimeException(e); + } + + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(org.apache.gravitino.lance.service.rest.LanceTableOperations.class); + resourceConfig.register( + new AbstractBinder() { + @Override + protected void configure() { + bind(namespaceWrapper).to(NamespaceWrapper.class).ranked(2); + bindFactory(MockServletRequestFactory.class).to(HttpServletRequest.class); + } + }); + + return resourceConfig; + } + + @BeforeAll + public static void setup() { + when(namespaceWrapper.asTableOps()).thenReturn(tableOps); + } + + @Test + void testCreateTable() { + String tableIds = "catalog.scheme.create_table"; + String delimiter = "."; + + // Test normal + CreateTableResponse createTableResponse = new CreateTableResponse(); + when(tableOps.createTable(any(), any(), any(), any(), any(), any())) + .thenReturn(createTableResponse); + + byte[] bytes = new byte[] {0x01, 0x02, 0x03}; + Response resp = + target(String.format("/v1/table/%s/create", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test illegal argument + when(tableOps.createTable(any(), any(), any(), any(), any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + + resp = + target(String.format("/v1/table/%s/create", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.createTable(any(), any(), any(), any(), any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/create", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(bytes, "application/vnd.apache.arrow.stream")); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testCreateEmptyTable() { + String tableIds = "catalog.scheme.create_empty_table"; + String delimiter = "."; + + // Test normal + CreateEmptyTableResponse createTableResponse = new CreateEmptyTableResponse(); + createTableResponse.setLocation("/path/to/table"); + createTableResponse.setProperties(ImmutableMap.of("key", "value")); + when(tableOps.createEmptyTable(any(), any(), any(), any())).thenReturn(createTableResponse); + + CreateEmptyTableRequest tableRequest = new CreateEmptyTableRequest(); + tableRequest.setLocation("/path/to/table"); + + Response resp = + target(String.format("/v1/table/%s/create-empty", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + CreateEmptyTableResponse response = resp.readEntity(CreateEmptyTableResponse.class); + Assertions.assertEquals(createTableResponse.getLocation(), response.getLocation()); + Assertions.assertEquals(createTableResponse.getProperties(), response.getProperties()); + + Mockito.reset(tableOps); + // Test illegal argument + when(tableOps.createEmptyTable(any(), any(), any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + + resp = + target(String.format("/v1/table/%s/create-empty", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.createEmptyTable(any(), any(), any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/create-empty", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testRegisterTable() { + String tableIds = "catalog.scheme.register_table"; + String delimiter = "."; + + // Test normal + RegisterTableResponse registerTableResponse = new RegisterTableResponse(); + registerTableResponse.setLocation("/path/to/registered_table"); + registerTableResponse.setProperties(ImmutableMap.of("key", "value")); + when(tableOps.registerTable(any(), any(), any(), any())).thenReturn(registerTableResponse); + + RegisterTableRequest tableRequest = new RegisterTableRequest(); + tableRequest.setLocation("/path/to/registered_table"); + tableRequest.setMode(RegisterTableRequest.ModeEnum.CREATE); + + Response resp = + target(String.format("/v1/table/%s/register", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + RegisterTableResponse response = resp.readEntity(RegisterTableResponse.class); + Assertions.assertEquals(registerTableResponse.getLocation(), response.getLocation()); + Assertions.assertEquals(registerTableResponse.getProperties(), response.getProperties()); + + // Test illegal argument + Mockito.reset(tableOps); + when(tableOps.registerTable(any(), any(), any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + resp = + target(String.format("/v1/table/%s/register", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.registerTable(any(), any(), any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/register", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testDeregisterTable() { + String tableIds = "catalog.scheme.deregister_table"; + String delimiter = "."; + + DeregisterTableRequest tableRequest = new DeregisterTableRequest(); + + DeregisterTableResponse deregisterTableResponse = new DeregisterTableResponse(); + deregisterTableResponse.setLocation("/path/to/deregistered_table"); + deregisterTableResponse.setProperties(ImmutableMap.of("key", "value")); + // Test normal + when(tableOps.deregisterTable(any(), any())).thenReturn(deregisterTableResponse); + + Response resp = + target(String.format("/v1/table/%s/deregister", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + DeregisterTableResponse response = resp.readEntity(DeregisterTableResponse.class); + Assertions.assertEquals(deregisterTableResponse.getLocation(), response.getLocation()); + Assertions.assertEquals(deregisterTableResponse.getProperties(), response.getProperties()); + + // Test illegal argument + Mockito.reset(tableOps); + when(tableOps.deregisterTable(any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + resp = + target(String.format("/v1/table/%s/deregister", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test not found exception + Mockito.reset(tableOps); + when(tableOps.deregisterTable(any(), any())) + .thenThrow( + LanceNamespaceException.notFound( + "Table not found", "NoSuchTableException", tableIds, "")); + resp = + target(String.format("/v1/table/%s/deregister", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.deregisterTable(any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/deregister", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testDescribeTable() { + String tableIds = "catalog.scheme.describe_table"; + String delimiter = "."; + + // Test normal + DescribeTableResponse createTableResponse = new DescribeTableResponse(); + createTableResponse.setLocation("/path/to/describe_table"); + createTableResponse.setProperties(ImmutableMap.of("key", "value")); + when(tableOps.describeTable(any(), any(), any())).thenReturn(createTableResponse); + + DescribeTableRequest tableRequest = new DescribeTableRequest(); + Response resp = + target(String.format("/v1/table/%s/describe", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + DescribeTableResponse response = resp.readEntity(DescribeTableResponse.class); + Assertions.assertEquals(createTableResponse.getLocation(), response.getLocation()); + Assertions.assertEquals(createTableResponse.getProperties(), response.getProperties()); + + // Test not found exception + Mockito.reset(tableOps); + when(tableOps.describeTable(any(), any(), any())) + .thenThrow( + LanceNamespaceException.notFound( + "Table not found", "NoSuchTableException", tableIds, "")); + resp = + target(String.format("/v1/table/%s/describe", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.describeTable(any(), any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/describe", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testCreateTableIndex() { + String tableIds = "catalog.scheme.to_create_index_table"; + String delimiter = "."; + + // Test normal + CreateTableIndexRequest tableRequest = new CreateTableIndexRequest(); + + CreateTableIndexResponse response = new CreateTableIndexResponse(); + response.setProperties(ImmutableMap.of("key", "value")); + when(tableOps.createTableIndex(any(), any(), any())).thenReturn(response); + + Response resp = + target(String.format("/v1/table/%s/create_index", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + response = resp.readEntity(CreateTableIndexResponse.class); + Assertions.assertEquals(response.getProperties(), response.getProperties()); + Assertions.assertEquals("value", response.getProperties().get("key")); + + Mockito.reset(tableOps); + // Test illegal argument + when(tableOps.createTableIndex(any(), any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + + resp = + target(String.format("/v1/table/%s/create_index", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test runtime exception + Mockito.reset(tableOps); + when(tableOps.createTableIndex(any(), any(), any())) + .thenThrow(new RuntimeException("Runtime exception")); + resp = + target(String.format("/v1/table/%s/create_index", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals( + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ErrorResponse errorResp = resp.readEntity(ErrorResponse.class); + Assertions.assertEquals("Runtime exception", errorResp.getError()); + Assertions.assertEquals(RuntimeException.class.getSimpleName(), errorResp.getType()); + } + + @Test + void testListTableIndices() { + String tableIds = "catalog.scheme.to_list_index_table"; + String delimiter = "."; + + ListTableIndicesRequest tableRequest = new ListTableIndicesRequest(); + + ListTableIndicesResponse response = new ListTableIndicesResponse(); + IndexContent indexContent = new IndexContent(); + indexContent.setIndexName("test_index"); + indexContent.setColumns(List.of("col1")); + response.setIndexes(List.of(indexContent)); + when(tableOps.listTableIndices(any(), any(), any())).thenReturn(response); + + Response resp = + target(String.format("/v1/table/%s/index/list", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + ListTableIndicesResponse actualResponse = resp.readEntity(ListTableIndicesResponse.class); + Assertions.assertEquals(1, actualResponse.getIndexes().size()); + Assertions.assertEquals("test_index", actualResponse.getIndexes().get(0).getIndexName()); + Assertions.assertEquals(List.of("col1"), actualResponse.getIndexes().get(0).getColumns()); + + Mockito.reset(tableOps); + + // Test illegal argument + when(tableOps.listTableIndices(any(), any(), any())) + .thenThrow(new IllegalArgumentException("Illegal argument")); + resp = + target(String.format("/v1/table/%s/index/list", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), resp.getStatus()); + Assertions.assertEquals(MediaType.APPLICATION_JSON_TYPE, resp.getMediaType()); + + // Test not found exception + Mockito.reset(tableOps); + when(tableOps.listTableIndices(any(), any(), any())) + .thenThrow( + LanceNamespaceException.notFound( + "Table not found", "NoSuchTableException", tableIds, "")); + resp = + target(String.format("/v1/table/%s/index/list", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), resp.getStatus()); + } + + @Test + void testRegisterTableSetsRegisterPropertyToTrue() { + String tableIds = "catalog.scheme.register_table_with_property"; + String delimiter = "."; + + // Reset mock to clear any previous test state + Mockito.reset(tableOps); + + // Test that the "register" property is set to "true" + RegisterTableResponse registerTableResponse = new RegisterTableResponse(); + registerTableResponse.setLocation("/path/to/registered_table"); + registerTableResponse.setProperties(ImmutableMap.of("key", "value", "register", "true")); + when(tableOps.registerTable(any(), any(), any(), any())).thenReturn(registerTableResponse); + + RegisterTableRequest tableRequest = new RegisterTableRequest(); + tableRequest.setLocation("/path/to/registered_table"); + tableRequest.setMode(RegisterTableRequest.ModeEnum.CREATE); + tableRequest.setProperties(ImmutableMap.of("custom-key", "custom-value")); + + Response resp = + target(String.format("/v1/table/%s/register", tableIds)) + .queryParam("delimiter", delimiter) + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.entity(tableRequest, MediaType.APPLICATION_JSON_TYPE)); + + Assertions.assertEquals(Response.Status.OK.getStatusCode(), resp.getStatus()); + + // Verify that registerTable was called with properties containing "register": "true" + Mockito.verify(tableOps) + .registerTable( + eq(tableIds), + eq(RegisterTableRequest.ModeEnum.CREATE), + eq(delimiter), + Mockito.argThat( + props -> + props != null + && "true".equals(props.get("register")) + && "/path/to/registered_table".equals(props.get("location")) + && "custom-value".equals(props.get("custom-key")))); + } +} diff --git a/scripts/mysql/upgrade-1.0.0-to-1.1.0-mysql.sql b/scripts/mysql/upgrade-1.0.0-to-1.1.0-mysql.sql index f40205419f3..a883bafcd39 100644 --- a/scripts/mysql/upgrade-1.0.0-to-1.1.0-mysql.sql +++ b/scripts/mysql/upgrade-1.0.0-to-1.1.0-mysql.sql @@ -29,4 +29,4 @@ CREATE TABLE IF NOT EXISTS `table_version_info` ( `version` BIGINT(20) UNSIGNED COMMENT 'table current version', `deleted_at` BIGINT(20) UNSIGNED DEFAULT 0 COMMENT 'table deletion timestamp, 0 means not deleted', UNIQUE KEY `uk_table_id_version_deleted_at` (`table_id`, `version`, `deleted_at`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin COMMENT 'table detail information including format, location, properties, partition, distribution, sort order, index and so on'; +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin COMMENT 'table detail information including format, location, properties, partition, distribution, sort order, index and so on'; \ No newline at end of file