diff --git a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/PolarisSparkCatalog.java b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/PolarisSparkCatalog.java
index 771c191c05..36ed872d3b 100644
--- a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/PolarisSparkCatalog.java
+++ b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/PolarisSparkCatalog.java
@@ -71,7 +71,12 @@ public Table loadTable(Identifier identifier) throws NoSuchTableException {
     try {
       GenericTable genericTable =
           this.polarisCatalog.loadGenericTable(Spark3Util.identifierToTableIdentifier(identifier));
-      return PolarisCatalogUtils.loadSparkTable(genericTable);
+      // Currently Hudi supports Spark Datasource V1, therefore we return a V1Table
+      if (PolarisCatalogUtils.useHudi(genericTable.getFormat())) {
+        return PolarisCatalogUtils.loadV1SparkTable(genericTable, identifier, name());
+      } else {
+        return PolarisCatalogUtils.loadV2SparkTable(genericTable);
+      }
     } catch (org.apache.iceberg.exceptions.NoSuchTableException e) {
       throw new NoSuchTableException(identifier);
     }
@@ -111,7 +116,12 @@ public Table createTable(
             baseLocation,
             null,
             properties);
-      return PolarisCatalogUtils.loadSparkTable(genericTable);
+      // Currently Hudi supports Spark Datasource V1, therefore we return a V1Table
+      if (PolarisCatalogUtils.useHudi(format)) {
+        return PolarisCatalogUtils.loadV1SparkTable(genericTable, identifier, name());
+      } else {
+        return PolarisCatalogUtils.loadV2SparkTable(genericTable);
+      }
     } catch (AlreadyExistsException e) {
       throw new TableAlreadyExistsException(identifier);
     }
diff --git a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/SparkCatalog.java b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/SparkCatalog.java
index 26c1fbbf3d..040638a479 100644
--- a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/SparkCatalog.java
+++ b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/SparkCatalog.java
@@ -30,6 +30,7 @@
 import org.apache.iceberg.spark.SupportsReplaceView;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.polaris.spark.utils.DeltaHelper;
+import org.apache.polaris.spark.utils.HudiHelper;
 import org.apache.polaris.spark.utils.PolarisCatalogUtils;
 import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
@@ -69,6 +70,7 @@ public class SparkCatalog
   @VisibleForTesting protected org.apache.iceberg.spark.SparkCatalog icebergsSparkCatalog = null;
   @VisibleForTesting protected PolarisSparkCatalog polarisSparkCatalog = null;
   @VisibleForTesting protected DeltaHelper deltaHelper = null;
+  @VisibleForTesting protected HudiHelper hudiHelper = null;

   @Override
   public String name() {
@@ -130,6 +132,7 @@ public void initialize(String name, CaseInsensitiveStringMap options) {
     this.catalogName = name;
     initRESTCatalog(name, options);
     this.deltaHelper = new DeltaHelper(options);
+    this.hudiHelper = new HudiHelper(options);
   }

   @Override
@@ -154,12 +157,16 @@ public Table createTable(
       throw new UnsupportedOperationException(
           "Create table without location key is not supported by Polaris. Please provide location or path on table creation.");
     }
-
     if (PolarisCatalogUtils.useDelta(provider)) {
       // For delta table, we load the delta catalog to help dealing with the
       // delta log creation.
       TableCatalog deltaCatalog = deltaHelper.loadDeltaCatalog(this.polarisSparkCatalog);
       return deltaCatalog.createTable(ident, schema, transforms, properties);
+    } else if (PolarisCatalogUtils.useHudi(provider)) {
+      // For creating the hudi table, we load HoodieCatalog
+      // to create the .hoodie folder in cloud storage
+      TableCatalog hudiCatalog = hudiHelper.loadHudiCatalog(this.polarisSparkCatalog);
+      return hudiCatalog.createTable(ident, schema, transforms, properties);
     } else {
       return this.polarisSparkCatalog.createTable(ident, schema, transforms, properties);
     }
@@ -180,8 +187,12 @@ public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchT
       // using ALTER TABLE ...SET LOCATION, and ALTER TABLE ... SET FILEFORMAT.
       TableCatalog deltaCatalog = deltaHelper.loadDeltaCatalog(this.polarisSparkCatalog);
       return deltaCatalog.alterTable(ident, changes);
+    } else if (PolarisCatalogUtils.useHudi(provider)) {
+      TableCatalog hudiCatalog = hudiHelper.loadHudiCatalog(this.polarisSparkCatalog);
+      return hudiCatalog.alterTable(ident, changes);
+    } else {
+      return this.polarisSparkCatalog.alterTable(ident);
     }
-    return this.polarisSparkCatalog.alterTable(ident);
   }
 }
diff --git a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/HudiHelper.java b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/HudiHelper.java
new file mode 100644
index 0000000000..105bad4c08
--- /dev/null
+++ b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/HudiHelper.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.polaris.spark.utils;
+
+import org.apache.iceberg.common.DynConstructors;
+import org.apache.polaris.spark.PolarisSparkCatalog;
+import org.apache.spark.sql.connector.catalog.DelegatingCatalogExtension;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+public class HudiHelper {
+  public static final String HUDI_CATALOG_IMPL_KEY = "hudi-catalog-impl";
+  private static final String DEFAULT_HUDI_CATALOG_CLASS =
+      "org.apache.spark.sql.hudi.catalog.HoodieCatalog";
+
+  private TableCatalog hudiCatalog = null;
+  private String hudiCatalogImpl = DEFAULT_HUDI_CATALOG_CLASS;
+
+  public HudiHelper(CaseInsensitiveStringMap options) {
+    if (options.get(HUDI_CATALOG_IMPL_KEY) != null) {
+      this.hudiCatalogImpl = options.get(HUDI_CATALOG_IMPL_KEY);
+    }
+  }
+
+  public TableCatalog loadHudiCatalog(PolarisSparkCatalog polarisSparkCatalog) {
+    if (this.hudiCatalog != null) {
+      return this.hudiCatalog;
+    }
+
+    DynConstructors.Ctor<TableCatalog> ctor;
+    try {
+      ctor = DynConstructors.builder(TableCatalog.class).impl(hudiCatalogImpl).buildChecked();
+    } catch (NoSuchMethodException e) {
+      throw new IllegalArgumentException(
+          String.format("Cannot initialize Hudi Catalog %s: %s", hudiCatalogImpl, e.getMessage()),
+          e);
+    }
+
+    try {
+      this.hudiCatalog = ctor.newInstance();
+    } catch (ClassCastException e) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Cannot initialize Hudi Catalog, %s does not implement TableCatalog.",
+              hudiCatalogImpl),
+          e);
+    }
+
+    // Set the Polaris Spark catalog as the delegate catalog of the Hudi catalog;
+    // it will be used in HoodieCatalog's loadTable.
+    ((DelegatingCatalogExtension) this.hudiCatalog).setDelegateCatalog(polarisSparkCatalog);
+    return this.hudiCatalog;
+  }
+}
diff --git a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/PolarisCatalogUtils.java b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/PolarisCatalogUtils.java
index 98016b71fd..5493f0dc36 100644
--- a/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/PolarisCatalogUtils.java
+++ b/plugins/spark/v3.5/spark/src/main/java/org/apache/polaris/spark/utils/PolarisCatalogUtils.java
@@ -29,14 +29,20 @@
 import org.apache.iceberg.spark.SparkCatalog;
 import org.apache.polaris.spark.rest.GenericTable;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.TableIdentifier;
+import org.apache.spark.sql.catalyst.catalog.CatalogTable;
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType;
+import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
 import org.apache.spark.sql.connector.catalog.TableProvider;
 import org.apache.spark.sql.execution.datasources.DataSource;
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import scala.Option;

 public class PolarisCatalogUtils {
+
   public static final String TABLE_PROVIDER_KEY = "provider";
   public static final String TABLE_PATH_KEY = "path";

@@ -50,6 +56,10 @@ public static boolean useDelta(String provider) {
     return "delta".equalsIgnoreCase(provider);
   }

+  public static boolean useHudi(String provider) {
+    return "hudi".equalsIgnoreCase(provider);
+  }
+
   /**
    * For tables whose location is managed by Spark Session Catalog, there will be no location or
    * path in the properties.
@@ -61,16 +71,11 @@ public static boolean isTableWithSparkManagedLocation(Map<String, String> proper
   }

   /**
-   * Load spark table using DataSourceV2.
-   *
-   * @return V2Table if DataSourceV2 is available for the table format. For delta table, it returns
-   *     DeltaTableV2.
+   * Normalize table properties for loading Spark tables by ensuring the TABLE_PATH_KEY is properly
+   * set. DataSourceV2 requires the path property on table loading.
    */
-  public static Table loadSparkTable(GenericTable genericTable) {
-    SparkSession sparkSession = SparkSession.active();
-    TableProvider provider =
-        DataSource.lookupDataSourceV2(genericTable.getFormat(), sparkSession.sessionState().conf())
-            .get();
+  private static Map<String, String> normalizeTablePropertiesForLoadSparkTable(
+      GenericTable genericTable) {
     Map<String, String> properties = genericTable.getProperties();
     boolean hasLocationClause = properties.get(TableCatalog.PROP_LOCATION) != null;
     boolean hasPathClause = properties.get(TABLE_PATH_KEY) != null;
@@ -87,10 +92,80 @@ public static Table loadSparkTable(GenericTable genericTable) {
         tableProperties.put(TABLE_PATH_KEY, properties.get(TableCatalog.PROP_LOCATION));
       }
     }
+    return tableProperties;
+  }
+
+  /**
+   * Load spark table using DataSourceV2.
+   *
+   * @return V2Table if DataSourceV2 is available for the table format. For delta table, it returns
+   *     DeltaTableV2.
+   */
+  public static Table loadV2SparkTable(GenericTable genericTable) {
+    SparkSession sparkSession = SparkSession.active();
+    TableProvider provider =
+        DataSource.lookupDataSourceV2(genericTable.getFormat(), sparkSession.sessionState().conf())
+            .get();
+    Map<String, String> tableProperties = normalizeTablePropertiesForLoadSparkTable(genericTable);
     return DataSourceV2Utils.getTableFromProvider(
         provider, new CaseInsensitiveStringMap(tableProperties), scala.Option.empty());
   }

+  /**
+   * Return a Spark V1Table for formats that do not use DataSourceV2. Currently, this is being used
+   * for Hudi tables.
+   */
+  public static Table loadV1SparkTable(
+      GenericTable genericTable, Identifier identifier, String catalogName) {
+    Map<String, String> tableProperties = normalizeTablePropertiesForLoadSparkTable(genericTable);
+
+    // Need full identifier in order to construct CatalogTable
+    String namespacePath = String.join(".", identifier.namespace());
+    TableIdentifier tableIdentifier =
+        new TableIdentifier(
+            identifier.name(), Option.apply(namespacePath), Option.apply(catalogName));
+
+    scala.collection.immutable.Map<String, String> scalaOptions =
+        (scala.collection.immutable.Map<String, String>)
+            scala.collection.immutable.Map$.MODULE$.apply(
+                scala.collection.JavaConverters.mapAsScalaMap(tableProperties).toSeq());
+
+    org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat storage =
+        DataSource.buildStorageFormatFromOptions(scalaOptions);
+
+    // Currently the Polaris generic table does not contain any schema information, partition
+    // columns, stats, etc.
+    // For now we fill in the parameters we have from the Polaris catalog, and let the underlying
+    // client resolve the rest within its own catalog implementation.
+    org.apache.spark.sql.types.StructType emptySchema = new org.apache.spark.sql.types.StructType();
+    scala.collection.immutable.Seq<String> emptyStringSeq =
+        scala.collection.JavaConverters.asScalaBuffer(new java.util.ArrayList<String>()).toList();
+    CatalogTable catalogTable =
+        new CatalogTable(
+            tableIdentifier,
+            CatalogTableType.EXTERNAL(),
+            storage,
+            emptySchema,
+            Option.apply(genericTable.getFormat()),
+            emptyStringSeq,
+            scala.Option.empty(),
+            genericTable.getProperties().get(TableCatalog.PROP_OWNER),
+            System.currentTimeMillis(),
+            -1L,
+            "",
+            scalaOptions,
+            scala.Option.empty(),
+            scala.Option.empty(),
+            scala.Option.empty(),
+            emptyStringSeq,
+            false,
+            true,
+            scala.collection.immutable.Map$.MODULE$.empty(),
+            scala.Option.empty());
+
+    return new org.apache.spark.sql.connector.catalog.V1Table(catalogTable);
+  }
+
   /**
    * Get the catalogAuth field inside the RESTSessionCatalog used by Iceberg Spark Catalog use
    * reflection. TODO: Deprecate this function once the iceberg client is updated to 1.9.0 to use
diff --git a/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/NoopHudiCatalog.java b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/NoopHudiCatalog.java
new file mode 100644
index 0000000000..93862ea3c1
--- /dev/null
+++ b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/NoopHudiCatalog.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.polaris.spark;
+
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.connector.catalog.DelegatingCatalogExtension;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableChange;
+
+/**
+ * This is a fake hudi catalog class that is used for testing. This class is a noop class that
+ * directly passes all calls to the delegate CatalogPlugin configured as part of
+ * DelegatingCatalogExtension.
+ */
+public class NoopHudiCatalog extends DelegatingCatalogExtension {
+
+  @Override
+  public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchTableException {
+    return super.loadTable(ident);
+  }
+}
diff --git a/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/SparkCatalogTest.java b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/SparkCatalogTest.java
index 6aa4a3c089..125c6d1d59 100644
--- a/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/SparkCatalogTest.java
+++ b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/SparkCatalogTest.java
@@ -35,6 +35,7 @@
 import org.apache.iceberg.spark.actions.SparkActions;
 import org.apache.iceberg.spark.source.SparkTable;
 import org.apache.polaris.spark.utils.DeltaHelper;
+import org.apache.polaris.spark.utils.HudiHelper;
 import org.apache.polaris.spark.utils.PolarisCatalogUtils;
 import org.apache.spark.SparkContext;
 import org.apache.spark.sql.SparkSession;
@@ -58,6 +59,7 @@
 import org.apache.spark.sql.internal.SessionState;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
@@ -67,6 +69,9 @@
 import scala.Option;

 public class SparkCatalogTest {
+  private static MockedStatic<SparkSession> mockedStaticSparkSession;
+  private static SparkSession mockedSession;
+
   private static class InMemoryIcebergSparkCatalog extends org.apache.iceberg.spark.SparkCatalog {
     private PolarisInMemoryCatalog inMemoryCatalog = null;
@@ -104,6 +109,7 @@ public void initialize(String name, CaseInsensitiveStringMap options) {
       this.polarisSparkCatalog.initialize(name, options);

       this.deltaHelper = new DeltaHelper(options);
+      this.hudiHelper = new HudiHelper(options);
     }
   }

@@ -122,25 +128,50 @@ public void setup() throws Exception {
     catalogConfig.put("cache-enabled", "false");
     catalogConfig.put(
         DeltaHelper.DELTA_CATALOG_IMPL_KEY, "org.apache.polaris.spark.NoopDeltaCatalog");
+    catalogConfig.put(HudiHelper.HUDI_CATALOG_IMPL_KEY, "org.apache.polaris.spark.NoopHudiCatalog");
     catalog = new InMemorySparkCatalog();
     Configuration conf = new Configuration();
-    try (MockedStatic<SparkSession> mockedStaticSparkSession =
-            Mockito.mockStatic(SparkSession.class);
-        MockedStatic<SparkUtil> mockedSparkUtil = Mockito.mockStatic(SparkUtil.class)) {
-      SparkSession mockedSession = Mockito.mock(SparkSession.class);
-      mockedStaticSparkSession.when(SparkSession::active).thenReturn(mockedSession);
+
+    // Setup persistent SparkSession mock
+    mockedStaticSparkSession = Mockito.mockStatic(SparkSession.class);
+    mockedSession = Mockito.mock(SparkSession.class);
+    org.apache.spark.sql.RuntimeConfig mockedConfig =
+        Mockito.mock(org.apache.spark.sql.RuntimeConfig.class);
+    SparkContext mockedContext = Mockito.mock(SparkContext.class);
+    SessionState mockedSessionState =
+        Mockito.mock(SessionState.class);
+    SQLConf mockedSQLConf = Mockito.mock(SQLConf.class);
+
+    mockedStaticSparkSession.when(SparkSession::active).thenReturn(mockedSession);
+    Mockito.when(mockedSession.conf()).thenReturn(mockedConfig);
+    Mockito.when(mockedSession.sessionState()).thenReturn(mockedSessionState);
+    Mockito.when(mockedSessionState.conf()).thenReturn(mockedSQLConf);
+    Mockito.when(mockedConfig.get("spark.sql.extensions", null))
+        .thenReturn(
+            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,"
+                + "io.delta.sql.DeltaSparkSessionExtension,"
+                + "org.apache.spark.sql.hudi.HoodieSparkSessionExtension");
+    Mockito.when(mockedConfig.get("spark.sql.warehouse.dir", "spark-warehouse"))
+        .thenReturn("/tmp/test-warehouse");
+    Mockito.when(mockedSession.sparkContext()).thenReturn(mockedContext);
+    Mockito.when(mockedContext.applicationId()).thenReturn("appId");
+    Mockito.when(mockedContext.sparkUser()).thenReturn("test-user");
+    Mockito.when(mockedContext.version()).thenReturn("3.5");
+
+    try (MockedStatic<SparkUtil> mockedSparkUtil = Mockito.mockStatic(SparkUtil.class)) {
       mockedSparkUtil
           .when(() -> SparkUtil.hadoopConfCatalogOverrides(mockedSession, catalogName))
           .thenReturn(conf);
-      SparkContext mockedContext = Mockito.mock(SparkContext.class);
-      Mockito.when(mockedSession.sparkContext()).thenReturn(mockedContext);
-      Mockito.when(mockedContext.applicationId()).thenReturn("appId");
-      Mockito.when(mockedContext.sparkUser()).thenReturn("test-user");
-      Mockito.when(mockedContext.version()).thenReturn("3.5");
       catalog.initialize(catalogName, new CaseInsensitiveStringMap(catalogConfig));
+      catalog.createNamespace(defaultNS, Maps.newHashMap());
+    }
+  }
+
+  @AfterEach
+  public void tearDown() {
+    if (mockedStaticSparkSession != null) {
+      mockedStaticSparkSession.close();
     }
-    catalog.createNamespace(defaultNS, Maps.newHashMap());
   }

   @Test
@@ -402,7 +433,7 @@ void testIcebergTableOperations() throws Exception {
   }

   @ParameterizedTest
-  @ValueSource(strings = {"delta", "csv"})
+  @ValueSource(strings = {"delta", "hudi", "csv"})
   void testCreateAndLoadGenericTable(String format) throws Exception {
     Identifier identifier = Identifier.of(defaultNS, "generic-test-table");
     createAndValidateGenericTableWithLoad(catalog, identifier, defaultSchema, format);
@@ -418,7 +449,6 @@ void testCreateAndLoadGenericTable(String format) throws Exception {
             () -> catalog.createTable(identifier, defaultSchema, new Transform[0], newProperties))
         .isInstanceOf(TableAlreadyExistsException.class);

-    // drop the iceberg table
     catalog.dropTable(identifier);
     assertThatThrownBy(() -> catalog.loadTable(identifier))
         .isInstanceOf(NoSuchTableException.class);
@@ -428,8 +458,9 @@
   @Test
   void testMixedTables() throws Exception {
     // create two iceberg tables, and three non-iceberg tables
-    String[] tableNames = new String[] {"iceberg1", "iceberg2", "delta1", "csv1", "delta2"};
-    String[] tableFormats = new String[] {"iceberg", null, "delta", "csv", "delta"};
+    String[] tableNames =
+        new String[] {"iceberg1", "iceberg2", "delta1", "csv1", "delta2", "hudi1", "hudi2"};
+    String[] tableFormats = new String[] {"iceberg", null, "delta", "csv", "delta", "hudi", "hudi"};
     for (int i = 0; i < tableNames.length; i++) {
       Identifier identifier = Identifier.of(defaultNS, tableNames[i]);
       createAndValidateGenericTableWithLoad(catalog, identifier, defaultSchema, tableFormats[i]);
@@ -445,8 +476,9 @@ void testMixedTables() throws Exception {
     // drop iceberg2 and delta1 table
     catalog.dropTable(Identifier.of(defaultNS, "iceberg2"));
     catalog.dropTable(Identifier.of(defaultNS, "delta2"));
+    catalog.dropTable(Identifier.of(defaultNS, "hudi2"));

-    String[] remainingTableNames = new String[] {"iceberg1", "delta1", "csv1"};
+    String[] remainingTableNames = new String[] {"iceberg1", "delta1", "csv1", "hudi1"};
     Identifier[] remainingTableIndents = catalog.listTables(defaultNS);
     assertThat(remainingTableIndents.length).isEqualTo(remainingTableNames.length);
     for (String name : remainingTableNames) {
@@ -465,12 +497,15 @@ void testAlterAndRenameTable() throws Exception {
     String icebergTableName = "iceberg-table";
     String deltaTableName = "delta-table";
     String csvTableName = "csv-table";
+    String hudiTableName = "hudi-table";
     Identifier icebergIdent = Identifier.of(defaultNS, icebergTableName);
     Identifier deltaIdent = Identifier.of(defaultNS, deltaTableName);
     Identifier csvIdent = Identifier.of(defaultNS, csvTableName);
+    Identifier hudiIdent = Identifier.of(defaultNS, hudiTableName);
     createAndValidateGenericTableWithLoad(catalog, icebergIdent, defaultSchema, "iceberg");
     createAndValidateGenericTableWithLoad(catalog, deltaIdent, defaultSchema, "delta");
     createAndValidateGenericTableWithLoad(catalog, csvIdent, defaultSchema, "csv");
+    createAndValidateGenericTableWithLoad(catalog, hudiIdent, defaultSchema, "hudi");

     // verify alter iceberg table
     Table newIcebergTable =
@@ -488,17 +523,18 @@ void testAlterAndRenameTable() throws Exception {

     // verify alter delta table is a no-op, and alter csv table throws an exception
     SQLConf conf = new SQLConf();
-    try (MockedStatic<SparkSession> mockedStaticSparkSession =
-            Mockito.mockStatic(SparkSession.class);
-        MockedStatic<DataSource> mockedStaticDS = Mockito.mockStatic(DataSource.class);
+    try (MockedStatic<DataSource> mockedStaticDS = Mockito.mockStatic(DataSource.class);
         MockedStatic<DataSourceV2Utils> mockedStaticDSV2 =
             Mockito.mockStatic(DataSourceV2Utils.class)) {
-      SparkSession mockedSession = Mockito.mock(SparkSession.class);
-      mockedStaticSparkSession.when(SparkSession::active).thenReturn(mockedSession);
       SessionState mockedState = Mockito.mock(SessionState.class);
       Mockito.when(mockedSession.sessionState()).thenReturn(mockedState);
       Mockito.when(mockedState.conf()).thenReturn(conf);

+      // Mock SessionCatalog for Hudi support
+      org.apache.spark.sql.catalyst.catalog.SessionCatalog mockedSessionCatalog =
+          Mockito.mock(org.apache.spark.sql.catalyst.catalog.SessionCatalog.class);
+      Mockito.when(mockedState.catalog()).thenReturn(mockedSessionCatalog);
+
       TableProvider deltaProvider = Mockito.mock(TableProvider.class);
       mockedStaticDS
           .when(() -> DataSource.lookupDataSourceV2(Mockito.eq("delta"), Mockito.any()))
@@ -551,18 +587,21 @@ void testPurgeInvalidateTable() throws Exception {
     Identifier icebergIdent = Identifier.of(defaultNS, "iceberg-table");
     Identifier deltaIdent = Identifier.of(defaultNS, "delta-table");
+    Identifier hudiIdent = Identifier.of(defaultNS, "hudi-table");
     createAndValidateGenericTableWithLoad(catalog, icebergIdent, defaultSchema, "iceberg");
     createAndValidateGenericTableWithLoad(catalog, deltaIdent, defaultSchema, "delta");
-
+    createAndValidateGenericTableWithLoad(catalog, hudiIdent, defaultSchema, "hudi");
     // test invalidate table is a no op today
     catalog.invalidateTable(icebergIdent);
     catalog.invalidateTable(deltaIdent);
+    catalog.invalidateTable(hudiIdent);

     Identifier[] tableIdents = catalog.listTables(defaultNS);
-    assertThat(tableIdents.length).isEqualTo(2);
+    assertThat(tableIdents.length).isEqualTo(3);

     // verify purge tables drops the table
     catalog.purgeTable(deltaIdent);
+    catalog.purgeTable(hudiIdent);
     assertThat(catalog.listTables(defaultNS).length).isEqualTo(1);

     // purge iceberg table triggers file deletion
@@ -588,42 +627,60 @@ private void createAndValidateGenericTableWithLoad(
     properties.put(PolarisCatalogUtils.TABLE_PROVIDER_KEY, format);
     properties.put(
         TableCatalog.PROP_LOCATION,
-        String.format("file:///tmp/delta/path/to/table/%s/", identifier.name()));
-
-    SQLConf conf = new SQLConf();
-    try (MockedStatic<SparkSession> mockedStaticSparkSession =
-            Mockito.mockStatic(SparkSession.class);
-        MockedStatic<DataSource> mockedStaticDS = Mockito.mockStatic(DataSource.class);
-        MockedStatic<DataSourceV2Utils> mockedStaticDSV2 =
-            Mockito.mockStatic(DataSourceV2Utils.class)) {
-      SparkSession mockedSession = Mockito.mock(SparkSession.class);
-      mockedStaticSparkSession.when(SparkSession::active).thenReturn(mockedSession);
-      SessionState mockedState = Mockito.mock(SessionState.class);
-      Mockito.when(mockedSession.sessionState()).thenReturn(mockedState);
-      Mockito.when(mockedState.conf()).thenReturn(conf);
+        String.format("file:///tmp/%s/path/to/table/%s/", format, identifier.name()));

-      TableProvider provider = Mockito.mock(TableProvider.class);
-      mockedStaticDS
-          .when(() -> DataSource.lookupDataSourceV2(Mockito.eq(format), Mockito.any()))
-          .thenReturn(Option.apply(provider));
-      V1Table table = Mockito.mock(V1Table.class);
-      mockedStaticDSV2
-          .when(
-              () ->
-                  DataSourceV2Utils.getTableFromProvider(
-                      Mockito.eq(provider), Mockito.any(), Mockito.any()))
-          .thenReturn(table);
+    if (PolarisCatalogUtils.useIceberg(format)) {
       Table createdTable =
           sparkCatalog.createTable(identifier, schema, new Transform[0], properties);
       Table loadedTable = sparkCatalog.loadTable(identifier);
-      // verify the create and load table result
-      if (PolarisCatalogUtils.useIceberg(format)) {
-        // iceberg SparkTable is returned for iceberg tables
-        assertThat(createdTable).isInstanceOf(SparkTable.class);
-        assertThat(loadedTable).isInstanceOf(SparkTable.class);
-      } else {
-        // Spark V1 table is returned for non-iceberg tables
+      // verify iceberg SparkTable is returned for iceberg tables
+      assertThat(createdTable).isInstanceOf(SparkTable.class);
+      assertThat(loadedTable).isInstanceOf(SparkTable.class);
+    } else {
+      // For non-Iceberg tables, use mocking
+      try (MockedStatic<DataSource> mockedStaticDS = Mockito.mockStatic(DataSource.class);
+          MockedStatic<DataSourceV2Utils> mockedStaticDSV2 =
+              Mockito.mockStatic(DataSourceV2Utils.class);
+          MockedStatic<PolarisCatalogUtils> mockedStaticUtils =
+              Mockito.mockStatic(PolarisCatalogUtils.class)) {
+
+        V1Table table = Mockito.mock(V1Table.class);
+
+        // Mock the routing utility methods
+        mockedStaticUtils
+            .when(() -> PolarisCatalogUtils.useHudi(Mockito.eq(format)))
+            .thenCallRealMethod();
+
+        if ("hudi".equalsIgnoreCase(format)) {
+          // For Hudi tables, mock the loadV1SparkTable method to return the mock table
+          mockedStaticUtils
+              .when(
+                  () ->
+                      PolarisCatalogUtils.loadV1SparkTable(
+                          Mockito.any(), Mockito.any(), Mockito.any()))
+              .thenReturn(table);
+        } else {
+          TableProvider provider = Mockito.mock(TableProvider.class);
+          mockedStaticDS
+              .when(() -> DataSource.lookupDataSourceV2(Mockito.eq(format), Mockito.any()))
+              .thenReturn(Option.apply(provider));
+          mockedStaticDSV2
+              .when(
+                  () ->
+                      DataSourceV2Utils.getTableFromProvider(
+                          Mockito.eq(provider), Mockito.any(), Mockito.any()))
+              .thenReturn(table);
+          mockedStaticUtils
+              .when(() -> PolarisCatalogUtils.loadV2SparkTable(Mockito.any()))
+              .thenCallRealMethod();
+        }
+
+        Table createdTable =
+            sparkCatalog.createTable(identifier, schema, new Transform[0], properties);
+        Table loadedTable = sparkCatalog.loadTable(identifier);
+
+        // verify Spark V1 table is returned for non-iceberg tables
         assertThat(createdTable).isInstanceOf(V1Table.class);
         assertThat(loadedTable).isInstanceOf(V1Table.class);
       }
diff --git a/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/rest/DeserializationTest.java b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/rest/DeserializationTest.java
index 0f7d3c99b3..6c2bb99dc3 100644
--- a/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/rest/DeserializationTest.java
+++ b/plugins/spark/v3.5/spark/src/test/java/org/apache/polaris/spark/rest/DeserializationTest.java
@@ -22,6 +22,9 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Stream;
@@ -66,11 +69,11 @@ public void setUp() {

   @ParameterizedTest
   @MethodSource("genericTableTestCases")
   public void testLoadGenericTableRESTResponse(
-      String baseLocation, String doc, Map<String, String> properties)
+      String baseLocation, String doc, Map<String, String> properties, String format)
       throws JsonProcessingException {
     GenericTable.Builder tableBuilder =
         GenericTable.builder()
-            .setFormat("delta")
+            .setFormat(format)
             .setName("test-table")
             .setProperties(properties)
             .setDoc(doc);
@@ -82,7 +85,7 @@ public void testLoadGenericTableRESTResponse(
     String json = mapper.writeValueAsString(response);
     LoadGenericTableRESTResponse deserializedResponse =
         mapper.readValue(json, LoadGenericTableRESTResponse.class);
-    assertThat(deserializedResponse.getTable().getFormat()).isEqualTo("delta");
+    assertThat(deserializedResponse.getTable().getFormat()).isEqualTo(format);
     assertThat(deserializedResponse.getTable().getName()).isEqualTo("test-table");
     assertThat(deserializedResponse.getTable().getDoc()).isEqualTo(doc);
     assertThat(deserializedResponse.getTable().getProperties().size()).isEqualTo(properties.size());
@@ -92,13 +95,13 @@

   @ParameterizedTest
   @MethodSource("genericTableTestCases")
   public void testCreateGenericTableRESTRequest(
-      String baseLocation, String doc, Map<String, String> properties)
+      String baseLocation, String doc, Map<String, String> properties, String format)
       throws JsonProcessingException {
     CreateGenericTableRESTRequest request =
         new CreateGenericTableRESTRequest(
             CreateGenericTableRequest.builder()
                 .setName("test-table")
-                .setFormat("delta")
+                .setFormat(format)
                 .setDoc(doc)
                 .setBaseLocation(baseLocation)
                 .setProperties(properties)
@@ -107,7 +110,7 @@ public void testCreateGenericTableRESTRequest(
     CreateGenericTableRESTRequest deserializedRequest =
         mapper.readValue(json, CreateGenericTableRESTRequest.class);
     assertThat(deserializedRequest.getName()).isEqualTo("test-table");
-    assertThat(deserializedRequest.getFormat()).isEqualTo("delta");
+    assertThat(deserializedRequest.getFormat()).isEqualTo(format);
     assertThat(deserializedRequest.getDoc()).isEqualTo(doc);
     assertThat(deserializedRequest.getProperties().size()).isEqualTo(properties.size());
     assertThat(deserializedRequest.getBaseLocation()).isEqualTo(baseLocation);
@@ -159,11 +162,14 @@ private static Stream<Arguments> genericTableTestCases() {
     var properties = Maps.newHashMap();
     properties.put("location", "s3://path/to/table/");
     var baseLocation = "s3://path/to/table/";
-    return Stream.of(
-        Arguments.of(null, doc, properties),
-        Arguments.of(baseLocation, doc, properties),
-        Arguments.of(null, null, Maps.newHashMap()),
-        Arguments.of(baseLocation, doc, Maps.newHashMap()),
-        Arguments.of(baseLocation, null, properties));
+    List<Arguments> args = new ArrayList<>();
+    for (String format : Arrays.asList("delta", "hudi")) {
+      args.add(Arguments.of(null, doc, properties, format));
+      args.add(Arguments.of(baseLocation, doc, properties, format));
+      args.add(Arguments.of(null, null, Maps.newHashMap(), format));
+      args.add(Arguments.of(baseLocation, doc, Maps.newHashMap(), format));
+      args.add(Arguments.of(baseLocation, null, properties, format));
+    }
+    return args.stream();
   }
 }
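
Reviewer note, not part of the patch: a usage sketch of how the new Hudi path is exercised end to end. The Java snippet below is hypothetical; the catalog name ("polaris"), the URI, the warehouse value, and the table location are placeholders for a local Polaris deployment, and it assumes the Polaris Spark client and the Hudi Spark bundle are on the classpath. The "hudi-catalog-impl" option is only needed to override the default org.apache.spark.sql.hudi.catalog.HoodieCatalog resolved by HudiHelper. With such a session, createTable() for a "hudi" provider is routed through HoodieCatalog (which lays down the .hoodie folder), and loadTable() returns a Spark V1Table built by PolarisCatalogUtils.loadV1SparkTable.

import org.apache.spark.sql.SparkSession;

public class PolarisHudiExample {
  public static void main(String[] args) {
    // All configuration values below are placeholders for a local Polaris instance.
    SparkSession spark =
        SparkSession.builder()
            .appName("polaris-hudi-example")
            .config(
                "spark.sql.extensions",
                "org.apache.spark.sql.hudi.HoodieSparkSessionExtension")
            .config("spark.sql.catalog.polaris", "org.apache.polaris.spark.SparkCatalog")
            .config("spark.sql.catalog.polaris.uri", "http://localhost:8181/api/catalog")
            .config("spark.sql.catalog.polaris.warehouse", "my_catalog")
            // Optional: override the Hudi catalog implementation loaded by HudiHelper.
            .config(
                "spark.sql.catalog.polaris.hudi-catalog-impl",
                "org.apache.spark.sql.hudi.catalog.HoodieCatalog")
            .getOrCreate();

    spark.sql("CREATE NAMESPACE IF NOT EXISTS polaris.ns");
    // A location (or path) is required for non-Iceberg tables, per SparkCatalog#createTable.
    spark.sql(
        "CREATE TABLE polaris.ns.hudi_tbl (id INT, name STRING) USING hudi "
            + "LOCATION 'file:///tmp/polaris/ns/hudi_tbl'");
    spark.sql("SELECT * FROM polaris.ns.hudi_tbl").show();
  }
}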