diff --git a/api/src/main/java/org/apache/iceberg/Metrics.java b/api/src/main/java/org/apache/iceberg/Metrics.java
index d5367c448175..8a213ad8f839 100644
--- a/api/src/main/java/org/apache/iceberg/Metrics.java
+++ b/api/src/main/java/org/apache/iceberg/Metrics.java
@@ -24,8 +24,8 @@
 import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.nio.ByteBuffer;
-import java.util.HashMap;
 import java.util.Map;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.util.ByteBuffers;
 
 /**
@@ -230,7 +230,7 @@ private static Map<Integer, ByteBuffer> readByteBufferMap(ObjectInputStream in)
       return null;
     } else {
-      Map<Integer, ByteBuffer> result = new HashMap<>(size);
+      Map<Integer, ByteBuffer> result = Maps.newHashMapWithExpectedSize(size);
       for (int i = 0; i < size; ++i) {
         Integer key = (Integer) in.readObject();
diff --git a/build.gradle b/build.gradle
index 8a8812ae3ec8..40521e72795d 100644
--- a/build.gradle
+++ b/build.gradle
@@ -1099,11 +1099,15 @@ project(":iceberg-spark3-extensions") {
       exclude group: 'org.apache.arrow'
     }
 
+    testCompile project(path: ':iceberg-data', configuration: 'testArtifacts')
+    testCompile project(path: ':iceberg-orc', configuration: 'testArtifacts')
     testCompile project(path: ':iceberg-api', configuration: 'testArtifacts')
     testCompile project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts')
     testCompile project(path: ':iceberg-spark', configuration: 'testArtifacts')
     testCompile project(path: ':iceberg-spark3', configuration: 'testArtifacts')
+    testCompile "org.apache.avro:avro"
+
     spark31Implementation("org.apache.spark:spark-hive_2.12:${project.ext.Spark31Version}") {
       exclude group: 'org.apache.avro', module: 'avro'
       exclude group: 'org.apache.arrow'
diff --git a/core/src/main/java/org/apache/iceberg/avro/Avro.java b/core/src/main/java/org/apache/iceberg/avro/Avro.java
index 59715b01b0e7..3b4ba48acd37 100644
--- a/core/src/main/java/org/apache/iceberg/avro/Avro.java
+++ b/core/src/main/java/org/apache/iceberg/avro/Avro.java
@@ -636,4 +636,12 @@ public AvroIterable build() {
     }
   }
 
+  /**
+   * Returns the number of rows in the specified Avro file.
+   * @param file the Avro file
+   * @return the number of rows in the file
+   */
+  public static long rowCount(InputFile file) {
+    return AvroIO.findStartingRowPos(file::newStream, Long.MAX_VALUE);
+  }
 }
diff --git a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
index 27508470254d..a432c7639386 100644
--- a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
+++ b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
@@ -36,7 +36,9 @@
 import org.apache.iceberg.MetricsConfig;
 import org.apache.iceberg.PartitionField;
 import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.avro.Avro;
 import org.apache.iceberg.hadoop.HadoopInputFile;
+import org.apache.iceberg.io.InputFile;
 import org.apache.iceberg.mapping.NameMapping;
 import org.apache.iceberg.orc.OrcMetrics;
 import org.apache.iceberg.parquet.ParquetUtil;
@@ -91,7 +93,9 @@ private static List listAvroPartition(Map partitionPat
     return Arrays.stream(fs.listStatus(partition, HIDDEN_PATH_FILTER))
         .filter(FileStatus::isFile)
         .map(stat -> {
-          Metrics metrics = new Metrics(-1L, null, null, null);
+          InputFile file = HadoopInputFile.fromLocation(stat.getPath().toString(), conf);
+          long rowCount = Avro.rowCount(file);
+          Metrics metrics = new Metrics(rowCount, null, null, null);
           String partitionKey = spec.fields().stream()
               .map(PartitionField::name)
              .map(name -> String.format("%s=%s",
name, partitionPath.get(name))) diff --git a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java index 647c6260a4e7..8b415dc4554b 100644 --- a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java +++ b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java @@ -288,8 +288,7 @@ protected void doCommit(TableMetadata base, TableMetadata metadata) { throw new RuntimeException("Interrupted during commit", e); } finally { - cleanupMetadataAndUnlock(commitStatus, newMetadataLocation, lockId); - tableLevelMutex.unlock(); + cleanupMetadataAndUnlock(commitStatus, newMetadataLocation, lockId, tableLevelMutex); } } @@ -471,7 +470,8 @@ long acquireLock() throws UnknownHostException, TException, InterruptedException return lockId; } - private void cleanupMetadataAndUnlock(CommitStatus commitStatus, String metadataLocation, Optional lockId) { + private void cleanupMetadataAndUnlock(CommitStatus commitStatus, String metadataLocation, Optional lockId, + ReentrantLock tableLevelMutex) { try { if (commitStatus == CommitStatus.FAILURE) { // If we are sure the commit failed, clean up the uncommitted metadata file @@ -482,6 +482,7 @@ private void cleanupMetadataAndUnlock(CommitStatus commitStatus, String metadata throw e; } finally { unlock(lockId); + tableLevelMutex.unlock(); } } diff --git a/mr/src/main/java/org/apache/iceberg/mr/Catalogs.java b/mr/src/main/java/org/apache/iceberg/mr/Catalogs.java index b5cc63b42955..0bf04731124b 100644 --- a/mr/src/main/java/org/apache/iceberg/mr/Catalogs.java +++ b/mr/src/main/java/org/apache/iceberg/mr/Catalogs.java @@ -19,7 +19,6 @@ package org.apache.iceberg.mr; -import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.Properties; @@ -39,6 +38,7 @@ import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Streams; /** @@ -150,7 +150,7 @@ public static Table createTable(Configuration conf, Properties props) { String catalogName = props.getProperty(InputFormatConfig.CATALOG_NAME); // Create a table property map without the controlling properties - Map map = new HashMap<>(props.size()); + Map map = Maps.newHashMapWithExpectedSize(props.size()); for (Object key : props.keySet()) { if (!PROPERTIES_TO_REMOVE.contains(key)) { map.put(key.toString(), props.get(key).toString()); @@ -202,7 +202,15 @@ public static boolean dropTable(Configuration conf, Properties props) { */ public static boolean hiveCatalog(Configuration conf, Properties props) { String catalogName = props.getProperty(InputFormatConfig.CATALOG_NAME); - return CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE.equalsIgnoreCase(getCatalogType(conf, catalogName)); + String catalogType = getCatalogType(conf, catalogName); + if (catalogType != null) { + return CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE.equalsIgnoreCase(catalogType); + } + catalogType = getCatalogType(conf, ICEBERG_DEFAULT_CATALOG_NAME); + if (catalogType != null) { + return CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE.equalsIgnoreCase(catalogType); + } + return getCatalogProperties(conf, catalogName, catalogType).get(CatalogProperties.CATALOG_IMPL) == null; } @VisibleForTesting @@ -279,9 
+287,7 @@ private static String getCatalogType(Configuration conf, String catalogName) { } } else { String catalogType = conf.get(InputFormatConfig.CATALOG); - if (catalogType == null) { - return CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE; - } else if (catalogType.equals(LOCATION)) { + if (catalogType != null && catalogType.equals(LOCATION)) { return NO_CATALOG_TYPE; } else { return catalogType; diff --git a/mr/src/main/java/org/apache/iceberg/mr/hive/Deserializer.java b/mr/src/main/java/org/apache/iceberg/mr/hive/Deserializer.java index 458affdd7c60..47e9f3e0537d 100644 --- a/mr/src/main/java/org/apache/iceberg/mr/hive/Deserializer.java +++ b/mr/src/main/java/org/apache/iceberg/mr/hive/Deserializer.java @@ -33,6 +33,7 @@ import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.mr.hive.serde.objectinspector.WriteObjectInspector; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.schema.SchemaWithPartnerVisitor; import org.apache.iceberg.types.Type.PrimitiveType; import org.apache.iceberg.types.Types.ListType; @@ -232,7 +233,7 @@ private static class FixNameMappingObjectInspectorPair extends ObjectInspectorPa FixNameMappingObjectInspectorPair(Schema schema, ObjectInspectorPair pair) { super(pair.writerInspector(), pair.sourceInspector()); - this.sourceNameMap = new HashMap<>(schema.columns().size()); + this.sourceNameMap = Maps.newHashMapWithExpectedSize(schema.columns().size()); List fields = ((StructObjectInspector) sourceInspector()).getAllStructFieldRefs(); for (int i = 0; i < schema.columns().size(); ++i) { diff --git a/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergSerDe.java b/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergSerDe.java index 707db6f808ec..ce425c6d10a4 100644 --- a/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergSerDe.java +++ b/mr/src/main/java/org/apache/iceberg/mr/hive/HiveIcebergSerDe.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; @@ -46,6 +45,7 @@ import org.apache.iceberg.mr.InputFormatConfig; import org.apache.iceberg.mr.hive.serde.objectinspector.IcebergObjectInspector; import org.apache.iceberg.mr.mapred.Container; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,7 +55,7 @@ public class HiveIcebergSerDe extends AbstractSerDe { private ObjectInspector inspector; private Schema tableSchema; - private Map deserializers = new HashMap<>(1); + private Map deserializers = Maps.newHashMapWithExpectedSize(1); private Container row = new Container<>(); @Override diff --git a/mr/src/test/java/org/apache/iceberg/mr/TestCatalogs.java b/mr/src/test/java/org/apache/iceberg/mr/TestCatalogs.java index 0d5b1d63bb23..0e76ea1115ec 100644 --- a/mr/src/test/java/org/apache/iceberg/mr/TestCatalogs.java +++ b/mr/src/test/java/org/apache/iceberg/mr/TestCatalogs.java @@ -197,6 +197,7 @@ public void testLegacyLoadCatalogDefault() { Optional defaultCatalog = Catalogs.loadCatalog(conf, null); Assert.assertTrue(defaultCatalog.isPresent()); Assertions.assertThat(defaultCatalog.get()).isInstanceOf(HiveCatalog.class); + Assert.assertTrue(Catalogs.hiveCatalog(conf, new Properties())); } @Test @@ -205,6 +206,7 @@ public void testLegacyLoadCatalogHive() { Optional hiveCatalog = Catalogs.loadCatalog(conf, null); 
Assert.assertTrue(hiveCatalog.isPresent()); Assertions.assertThat(hiveCatalog.get()).isInstanceOf(HiveCatalog.class); + Assert.assertTrue(Catalogs.hiveCatalog(conf, new Properties())); } @Test @@ -214,6 +216,7 @@ public void testLegacyLoadCatalogHadoop() { Optional hadoopCatalog = Catalogs.loadCatalog(conf, null); Assert.assertTrue(hadoopCatalog.isPresent()); Assertions.assertThat(hadoopCatalog.get()).isInstanceOf(HadoopCatalog.class); + Assert.assertFalse(Catalogs.hiveCatalog(conf, new Properties())); } @Test @@ -223,12 +226,14 @@ public void testLegacyLoadCatalogCustom() { Optional customHadoopCatalog = Catalogs.loadCatalog(conf, null); Assert.assertTrue(customHadoopCatalog.isPresent()); Assertions.assertThat(customHadoopCatalog.get()).isInstanceOf(CustomHadoopCatalog.class); + Assert.assertFalse(Catalogs.hiveCatalog(conf, new Properties())); } @Test public void testLegacyLoadCatalogLocation() { conf.set(InputFormatConfig.CATALOG, Catalogs.LOCATION); Assert.assertFalse(Catalogs.loadCatalog(conf, null).isPresent()); + Assert.assertFalse(Catalogs.hiveCatalog(conf, new Properties())); } @Test @@ -241,9 +246,13 @@ public void testLegacyLoadCatalogUnknown() { @Test public void testLoadCatalogDefault() { - Optional defaultCatalog = Catalogs.loadCatalog(conf, "barCatalog"); + String catalogName = "barCatalog"; + Optional defaultCatalog = Catalogs.loadCatalog(conf, catalogName); Assert.assertTrue(defaultCatalog.isPresent()); Assertions.assertThat(defaultCatalog.get()).isInstanceOf(HiveCatalog.class); + Properties properties = new Properties(); + properties.put(InputFormatConfig.CATALOG_NAME, catalogName); + Assert.assertTrue(Catalogs.hiveCatalog(conf, properties)); } @Test @@ -254,6 +263,9 @@ public void testLoadCatalogHive() { Optional hiveCatalog = Catalogs.loadCatalog(conf, catalogName); Assert.assertTrue(hiveCatalog.isPresent()); Assertions.assertThat(hiveCatalog.get()).isInstanceOf(HiveCatalog.class); + Properties properties = new Properties(); + properties.put(InputFormatConfig.CATALOG_NAME, catalogName); + Assert.assertTrue(Catalogs.hiveCatalog(conf, properties)); } @Test @@ -267,6 +279,9 @@ public void testLoadCatalogHadoop() { Assert.assertTrue(hadoopCatalog.isPresent()); Assertions.assertThat(hadoopCatalog.get()).isInstanceOf(HadoopCatalog.class); Assert.assertEquals("HadoopCatalog{name=barCatalog, location=/tmp/mylocation}", hadoopCatalog.get().toString()); + Properties properties = new Properties(); + properties.put(InputFormatConfig.CATALOG_NAME, catalogName); + Assert.assertFalse(Catalogs.hiveCatalog(conf, properties)); } @Test @@ -279,6 +294,9 @@ public void testLoadCatalogHadoopWithLegacyWarehouseLocation() { Assert.assertTrue(hadoopCatalog.isPresent()); Assertions.assertThat(hadoopCatalog.get()).isInstanceOf(HadoopCatalog.class); Assert.assertEquals("HadoopCatalog{name=barCatalog, location=/tmp/mylocation}", hadoopCatalog.get().toString()); + Properties properties = new Properties(); + properties.put(InputFormatConfig.CATALOG_NAME, catalogName); + Assert.assertFalse(Catalogs.hiveCatalog(conf, properties)); } @Test @@ -291,6 +309,9 @@ public void testLoadCatalogCustom() { Optional customHadoopCatalog = Catalogs.loadCatalog(conf, catalogName); Assert.assertTrue(customHadoopCatalog.isPresent()); Assertions.assertThat(customHadoopCatalog.get()).isInstanceOf(CustomHadoopCatalog.class); + Properties properties = new Properties(); + properties.put(InputFormatConfig.CATALOG_NAME, catalogName); + Assert.assertFalse(Catalogs.hiveCatalog(conf, properties)); } @Test diff --git 
a/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java b/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java index 0fad4dae139e..a873be5f1aa6 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java @@ -44,7 +44,6 @@ import org.apache.iceberg.hadoop.HadoopInputFile; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.mapping.NameMapping; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Conversions; @@ -196,11 +195,17 @@ private static Optional fromOrcMin(Type type, ColumnStatistics colum min = Math.toIntExact((long) min); } } else if (columnStats instanceof DoubleColumnStatistics) { - // since Orc includes NaN for upper/lower bounds of floating point columns, and we don't want this behavior, - // we have tracked metrics for such columns ourselves and thus do not need to rely on Orc's column statistics. - Preconditions.checkNotNull(fieldMetrics, - "[BUG] Float or double type columns should have metrics being tracked by Iceberg Orc writers"); - min = fieldMetrics.lowerBound(); + if (fieldMetrics != null) { + // since Orc includes NaN for upper/lower bounds of floating point columns, and we don't want this behavior, + // we have tracked metrics for such columns ourselves and thus do not need to rely on Orc's column statistics. + min = fieldMetrics.lowerBound(); + } else { + // imported files will not have metrics that were tracked by Iceberg, so fall back to the file's metrics. + min = replaceNaN(((DoubleColumnStatistics) columnStats).getMinimum(), Double.NEGATIVE_INFINITY); + if (type.typeId() == Type.TypeID.FLOAT) { + min = ((Double) min).floatValue(); + } + } } else if (columnStats instanceof StringColumnStatistics) { min = ((StringColumnStatistics) columnStats).getMinimum(); } else if (columnStats instanceof DecimalColumnStatistics) { @@ -234,11 +239,17 @@ private static Optional fromOrcMax(Type type, ColumnStatistics colum max = Math.toIntExact((long) max); } } else if (columnStats instanceof DoubleColumnStatistics) { - // since Orc includes NaN for upper/lower bounds of floating point columns, and we don't want this behavior, - // we have tracked metrics for such columns ourselves and thus do not need to rely on Orc's column statistics. - Preconditions.checkNotNull(fieldMetrics, - "[BUG] Float or double type columns should have metrics being tracked by Iceberg Orc writers"); - max = fieldMetrics.upperBound(); + if (fieldMetrics != null) { + // since Orc includes NaN for upper/lower bounds of floating point columns, and we don't want this behavior, + // we have tracked metrics for such columns ourselves and thus do not need to rely on Orc's column statistics. + max = fieldMetrics.upperBound(); + } else { + // imported files will not have metrics that were tracked by Iceberg, so fall back to the file's metrics. 
+ max = replaceNaN(((DoubleColumnStatistics) columnStats).getMaximum(), Double.POSITIVE_INFINITY); + if (type.typeId() == Type.TypeID.FLOAT) { + max = ((Double) max).floatValue(); + } + } } else if (columnStats instanceof StringColumnStatistics) { max = ((StringColumnStatistics) columnStats).getMaximum(); } else if (columnStats instanceof DecimalColumnStatistics) { @@ -262,6 +273,10 @@ private static Optional fromOrcMax(Type type, ColumnStatistics colum return Optional.ofNullable(Conversions.toByteBuffer(type, truncateIfNeeded(Bound.UPPER, type, max, metricsMode))); } + private static Object replaceNaN(double value, double replacement) { + return Double.isNaN(value) ? replacement : value; + } + private static Object truncateIfNeeded(Bound bound, Type type, Object value, MetricsMode metricsMode) { // Out of the two types which could be truncated, string or binary, ORC only supports string bounds. // Therefore, truncation will be applied if needed only on string type. diff --git a/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java index ac659f6c7b13..862626d0cd6d 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java +++ b/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java @@ -34,6 +34,7 @@ public static void registerBucketUDF(SparkSession session, String funcName, Data SparkTypeToType typeConverter = new SparkTypeToType(); Type sourceIcebergType = typeConverter.atomic(sourceType); Transform bucket = Transforms.bucket(sourceIcebergType, numBuckets); - session.udf().register(funcName, bucket::apply, DataTypes.IntegerType); + session.udf().register(funcName, + value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType); } } diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java index 92c812a9b979..ef453c0cef2b 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java +++ b/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java @@ -79,8 +79,9 @@ public static Object convert(Type type, Object object) { return DateTimeUtils.fromJavaTimestamp((Timestamp) object); case BINARY: return ByteBuffer.wrap((byte[]) object); - case BOOLEAN: case INTEGER: + return ((Number) object).intValue(); + case BOOLEAN: case LONG: case FLOAT: case DOUBLE: diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java index 14785d7f27ca..f8ebe21b58a4 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java @@ -19,13 +19,22 @@ package org.apache.iceberg.spark.source; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; import java.util.List; import org.apache.iceberg.spark.IcebergSpark; import org.apache.iceberg.transforms.Transforms; import org.apache.iceberg.types.Types; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.CharType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.VarcharType; +import org.assertj.core.api.Assertions; import org.junit.AfterClass; import 
org.junit.Assert; import org.junit.BeforeClass; @@ -48,23 +57,132 @@ public static void stopSpark() { } @Test - public void testRegisterBucketUDF() { + public void testRegisterIntegerBucketUDF() { IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16); List results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList(); Assert.assertEquals(1, results.size()); Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), results.get(0).getInt(0)); + } + + @Test + public void testRegisterShortBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16); + List results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterByteBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16); + List results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1), + results.get(0).getInt(0)); + } + @Test + public void testRegisterLongBucketUDF() { IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16); - List results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList(); - Assert.assertEquals(1, results2.size()); + List results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList(); + Assert.assertEquals(1, results.size()); Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L), - results2.get(0).getInt(0)); + results.get(0).getInt(0)); + } + @Test + public void testRegisterStringBucketUDF() { IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16); - List results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList(); - Assert.assertEquals(1, results3.size()); + List results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterVarCharBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16); + List results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList(); + Assert.assertEquals(1, results.size()); Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"), - results3.get(0).getInt(0)); + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDateBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16); + List results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16) + 
.apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterTimestampBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16); + List results = + spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16) + .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBinaryBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16); + List results = + spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16) + .apply(ByteBuffer.wrap(new byte[]{0x00, 0x20, 0x00, 0x1F})), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterDecimalBucketUDF() { + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16); + List results = + spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList(); + Assert.assertEquals(1, results.size()); + Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16) + .apply(new BigDecimal("11.11")), + results.get(0).getInt(0)); + } + + @Test + public void testRegisterBooleanBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: boolean"); + } + + @Test + public void testRegisterDoubleBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: double"); + } + + @Test + public void testRegisterFloatBucketUDF() { + Assertions.assertThatThrownBy(() -> + IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot bucket by type: float"); } } diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java index dea01a10a647..8096688e9355 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -24,8 +24,25 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetricsConfig; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.io.FileAppender; +import 
org.apache.iceberg.io.OutputFile; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,12 +52,15 @@ import org.apache.spark.sql.types.StructType; import org.junit.After; import org.junit.Assert; +import org.junit.Assume; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import static org.apache.iceberg.types.Types.NestedField.optional; + public class TestAddFilesProcedure extends SparkExtensionsTestBase { private final String sourceTableName = "source_table"; @@ -106,6 +126,59 @@ public void addDataUnpartitionedOrc() { sql("SELECT * FROM %s ORDER BY id", tableName)); } + @Test + public void addAvroFile() throws Exception { + // Spark Session Catalog cannot load metadata tables + // with "The namespace in session catalog must have exactly one name part" + Assume.assumeFalse(catalogName.equals("spark_catalog")); + + // Create an Avro file + + Schema schema = SchemaBuilder.record("record").fields() + .requiredInt("id") + .requiredString("data") + .endRecord(); + GenericRecord record1 = new GenericData.Record(schema); + record1.put("id", 1L); + record1.put("data", "a"); + GenericRecord record2 = new GenericData.Record(schema); + record2.put("id", 2L); + record2.put("data", "b"); + File outputFile = temp.newFile("test.avro"); + + DatumWriter datumWriter = new GenericDatumWriter(schema); + DataFileWriter dataFileWriter = new DataFileWriter(datumWriter); + dataFileWriter.create(schema, outputFile); + dataFileWriter.append(record1); + dataFileWriter.append(record2); + dataFileWriter.close(); + + String createIceberg = + "CREATE TABLE %s (id Long, data String) USING iceberg"; + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')", + catalogName, tableName, outputFile.getPath()); + Assert.assertEquals(1L, result); + + List expected = Lists.newArrayList( + new Object[]{1L, "a"}, + new Object[]{2L, "b"} + ); + + assertEquals("Iceberg table contains correct data", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + List actualRecordCount = sql("select %s from %s.files", + DataFile.RECORD_COUNT.name(), + tableName); + List expectedRecordCount = Lists.newArrayList(); + expectedRecordCount.add(new Object[]{2L}); + assertEquals("Iceberg file metadata should have correct metadata count", + expectedRecordCount, actualRecordCount); + } + // TODO Adding spark-avro doesn't work in tests @Ignore public void addDataUnpartitionedAvro() { @@ -444,6 +517,42 @@ public void invalidDataImportPartitioned() { catalogName, tableName, fileTableDir.getAbsolutePath())); } + @Test + public void addOrcFileWithDoubleAndFloatColumns() throws Exception { + // Spark Session Catalog cannot load metadata tables + // with "The namespace in session catalog must have exactly one name part" + Assume.assumeFalse(catalogName.equals("spark_catalog")); + + // Create an ORC file + File outputFile = temp.newFile("test.orc"); + final int numRows = 5; + List expectedRecords = createOrcFile(outputFile, numRows); + String createIceberg = + "CREATE TABLE %s (x float, y double, z long) USING iceberg"; + sql(createIceberg, tableName); + + Object result = scalarSql("CALL 
%s.system.add_files('%s', '`orc`.`%s`')",
+        catalogName, tableName, outputFile.getPath());
+    Assert.assertEquals(1L, result);
+
+    List<Object[]> expected = expectedRecords.stream()
+        .map(record -> new Object[]{record.get(0), record.get(1), record.get(2)})
+        .collect(Collectors.toList());
+
+    // x goes 2.00, 1.99, 1.98, ...
+    assertEquals("Iceberg table contains correct data",
+        expected,
+        sql("SELECT * FROM %s ORDER BY x DESC", tableName));
+
+    List<Object[]> actualRecordCount = sql("select %s from %s.files",
+        DataFile.RECORD_COUNT.name(),
+        tableName);
+    List<Object[]> expectedRecordCount = Lists.newArrayList();
+    expectedRecordCount.add(new Object[]{(long) numRows});
+    assertEquals("Iceberg file metadata should have correct metadata count",
+        expectedRecordCount, actualRecordCount);
+  }
+
   private static final StructField[] struct = {
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
       new StructField("name", DataTypes.StringType, false, Metadata.empty()),
@@ -534,4 +643,36 @@ private void createPartitionedHiveTable() {
     partitionedDF.write().insertInto(sourceTableName);
     partitionedDF.write().insertInto(sourceTableName);
   }
+
+  // TODO: update this helper so the import test file is written without Iceberg's field IDs
+  public List<Record> createOrcFile(File orcFile, int numRows) throws IOException {
+    // The file must not exist when the ORC writer creates it; the TemporaryFolder rule cleans it up at the end.
+    if (orcFile.exists()) {
+      orcFile.delete();
+    }
+    final org.apache.iceberg.Schema icebergSchema = new org.apache.iceberg.Schema(
+        optional(1, "x", Types.FloatType.get()),
+        optional(2, "y", Types.DoubleType.get()),
+        optional(3, "z", Types.LongType.get())
+    );
+
+    List<Record> records = Lists.newArrayListWithExpectedSize(numRows);
+    for (int i = 0; i < numRows; i += 1) {
+      Record record = org.apache.iceberg.data.GenericRecord.create(icebergSchema);
+      record.setField("x", ((float) (100 - i)) / 100F + 1.0F); // 2.0f, 1.99f, 1.98f, ...
+      record.setField("y", ((double) i) / 100.0D + 2.0D); // 2.0d, 2.01d, 2.02d, ...
+      record.setField("z", 5_000_000_000L + i);
+      records.add(record);
+    }
+
+    OutputFile outFile = Files.localOutput(orcFile);
+    try (FileAppender<Record> appender = org.apache.iceberg.orc.ORC.write(outFile)
+        .schema(icebergSchema)
+        .metricsConfig(MetricsConfig.fromProperties(ImmutableMap.of("write.metadata.metrics.default", "none")))
+        .createWriterFunc(GenericOrcWriter::buildWriter)
+        .build()) {
+      appender.addAll(records);
+    }
+    return records;
+  }
 }
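
Note on core/src/main/java/org/apache/iceberg/avro/Avro.java: rowCount() delegates to AvroIO.findStartingRowPos with Long.MAX_VALUE as the target row, so the value it returns is the row position at end-of-file, i.e. the total record count of the file. A minimal sketch of how an import path such as TableMigrationUtil uses it; the location string and Configuration are placeholders:

    import org.apache.hadoop.conf.Configuration;
    import org.apache.iceberg.Metrics;
    import org.apache.iceberg.avro.Avro;
    import org.apache.iceberg.hadoop.HadoopInputFile;
    import org.apache.iceberg.io.InputFile;

    public class AvroRowCountSketch {
      // Builds import metrics for a plain Avro data file: only the record count is known,
      // so column sizes, value counts and bounds are left null.
      static Metrics avroImportMetrics(String location, Configuration conf) {
        InputFile file = HadoopInputFile.fromLocation(location, conf);
        long rowCount = Avro.rowCount(file);
        return new Metrics(rowCount, null, null, null);
      }
    }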
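
Note on the new HashMap<>(size) to Maps.newHashMapWithExpectedSize(size) changes (Metrics.java, Catalogs.java, Deserializer.java, HiveIcebergSerDe.java): HashMap's int constructor sets the initial table capacity, not the expected number of entries, so with the default 0.75 load factor a map built with new HashMap<>(n) rehashes before it ever holds n entries; Guava's helper pads the capacity so n entries fit without a resize. A small sketch using the relocated Guava package that Iceberg bundles:

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.iceberg.relocated.com.google.common.collect.Maps;

    public class ExpectedSizeSketch {
      public static void main(String[] args) {
        int expectedEntries = 16;

        // Capacity 16, load factor 0.75: the backing table resizes once the 13th entry goes in.
        Map<Integer, String> resized = new HashMap<>(expectedEntries);

        // Guava sizes the table for 16 entries up front, so filling it causes no rehash.
        Map<Integer, String> preSized = Maps.newHashMapWithExpectedSize(expectedEntries);

        for (int i = 0; i < expectedEntries; i++) {
          resized.put(i, "value-" + i);
          preSized.put(i, "value-" + i);
        }
      }
    }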
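
Note on hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java: passing tableLevelMutex into cleanupMetadataAndUnlock moves the JVM-level unlock into that method's own finally block, so an exception thrown while cleaning up metadata or releasing the Metastore lock can no longer leave the process-local mutex held. A stripped-down sketch of the resulting ordering, with the helper bodies elided:

    import java.util.concurrent.locks.ReentrantLock;

    public class UnlockOrderingSketch {
      private final ReentrantLock tableLevelMutex = new ReentrantLock();

      public void commit() {
        tableLevelMutex.lock();
        try {
          // ... acquire the Metastore lock, write and swap table metadata ...
        } finally {
          cleanupAndUnlock();
        }
      }

      private void cleanupAndUnlock() {
        try {
          // ... delete uncommitted metadata on failure, then release the Metastore lock ...
        } finally {
          tableLevelMutex.unlock(); // always runs, even if the cleanup or unlock above throws
        }
      }
    }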
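
Note on mr/src/main/java/org/apache/iceberg/mr/Catalogs.java: hiveCatalog() no longer assumes Hive whenever no catalog type is configured. It checks the named catalog's type first, then the default catalog's type, and only when neither is set does it treat the table as Hive-backed, and then only if no custom catalog implementation is configured. An illustrative reduction of that decision chain; the parameters stand in for the values the real getCatalogType/getCatalogProperties helpers look up:

    public class HiveCatalogResolutionSketch {
      static final String HIVE = "hive";

      // namedType: type of the catalog named in the table properties, if any
      // defaultType: type configured for the default catalog, if any
      // catalogImpl: custom catalog implementation class, if one is configured
      static boolean isHiveCatalog(String namedType, String defaultType, String catalogImpl) {
        if (namedType != null) {
          return HIVE.equalsIgnoreCase(namedType);
        }
        if (defaultType != null) {
          return HIVE.equalsIgnoreCase(defaultType);
        }
        // No type anywhere: fall back to Hive unless a custom catalog implementation is set.
        return catalogImpl == null;
      }
    }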
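
Note on orc/src/main/java/org/apache/iceberg/orc/OrcMetrics.java: files written by Iceberg's ORC writers carry Iceberg-tracked float/double metrics, but imported files do not, which is why the Preconditions check had to go. For imported files the code now falls back to ORC's DoubleColumnStatistics and widens a NaN minimum/maximum to -Infinity/+Infinity so it still works as a bound; FLOAT columns then narrow the double statistic back to float. A sketch of that fallback in isolation:

    public class OrcFloatBoundsSketch {
      // Same shape as the new replaceNaN helper: NaN cannot act as a lower/upper bound,
      // so it is replaced with the widest possible bound instead.
      static double replaceNaN(double value, double replacement) {
        return Double.isNaN(value) ? replacement : value;
      }

      public static void main(String[] args) {
        double orcMin = Double.NaN; // what ORC may report for a float/double column containing NaN
        double orcMax = 3.5d;

        double lower = replaceNaN(orcMin, Double.NEGATIVE_INFINITY);
        double upper = replaceNaN(orcMax, Double.POSITIVE_INFINITY);

        // For FLOAT columns the double statistic is narrowed back to float before serialization.
        float lowerAsFloat = (float) lower;

        System.out.println(lower + " .. " + upper + " (float lower: " + lowerAsFloat + ")");
      }
    }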
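
Note on spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java: the registered UDF now runs the incoming value through SparkValueConverter.convert before bucketing, which is what makes date, timestamp, decimal, binary and the small integer types line up with Iceberg's bucket transform (and why SparkValueConverter's INTEGER case now widens any Number via intValue()). A usage sketch mirroring the new tests, assuming a local SparkSession:

    import org.apache.iceberg.spark.IcebergSpark;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    public class BucketUdfSketch {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[1]")
            .appName("iceberg-bucket-udf-sketch")
            .getOrCreate();

        // 16-bucket UDF over DATE values; the UDF converts Spark's value to Iceberg's
        // representation before applying Transforms.bucket.
        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);

        spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").show();

        spark.stop();
      }
    }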