diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
index ac659f6c7b13..862626d0cd6d 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
@@ -34,6 +34,7 @@ public static void registerBucketUDF(SparkSession session, String funcName, Data
     SparkTypeToType typeConverter = new SparkTypeToType();
     Type sourceIcebergType = typeConverter.atomic(sourceType);
     Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, numBuckets);
-    session.udf().register(funcName, bucket::apply, DataTypes.IntegerType);
+    session.udf().register(funcName,
+        value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType);
   }
 }
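
The hunk above is the heart of the fix: bucket::apply handed the Iceberg transform whatever Java object Spark's UDF machinery produced, so the call only worked when Spark's representation happened to match the one the transform expects. Routing the value through SparkValueConverter.convert first normalizes it (java.sql.Timestamp to epoch micros, byte[] to ByteBuffer, and, with the change below, any integral Number to int). A minimal sketch of the failure mode; the "bucket" function here is a stand-in that casts to Integer the way Iceberg's integer bucket transform does internally, NOT the real Murmur3 bucket function:

    import java.util.function.Function;

    public class BucketUdfSketch {
      public static void main(String[] args) {
        int numBuckets = 16;
        // Stand-in for Transforms.bucket(IntegerType, 16): casts its input to Integer.
        Function<Object, Integer> bucket =
            value -> (((Integer) value).hashCode() & Integer.MAX_VALUE) % numBuckets;

        Object sparkValue = (short) 1;  // what Spark hands a UDF for a SMALLINT column

        // Old wiring (bucket::apply): throws ClassCastException, Short is not Integer.
        // bucket.apply(sparkValue);

        // New wiring: normalize the value first, as SparkValueConverter.convert now does.
        System.out.println(bucket.apply(((Number) sparkValue).intValue()));
      }
    }
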
diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
index 92c812a9b979..ef453c0cef2b 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
@@ -79,8 +79,9 @@ public static Object convert(Type type, Object object) {
         return DateTimeUtils.fromJavaTimestamp((Timestamp) object);
       case BINARY:
         return ByteBuffer.wrap((byte[]) object);
-      case BOOLEAN:
       case INTEGER:
+        return ((Number) object).intValue();
+      case BOOLEAN:
       case LONG:
       case FLOAT:
       case DOUBLE:
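
Before this hunk, INTEGER sat in the pass-through group with BOOLEAN, LONG, FLOAT, and DOUBLE, so a Short or Byte flowed through convert unchanged and broke the downstream cast. INTEGER now gets its own branch that widens any boxed Number via intValue(), while BOOLEAN moves down into the pass-through group. A condensed sketch of the resulting switch shape; the pass-through return and the default branch are abbreviations assumed from context, not the full method:

    // Condensed sketch of SparkValueConverter.convert(...) after this hunk; only
    // the branches visible in the diff are reproduced, everything else is elided.
    static Object convert(org.apache.iceberg.types.Type type, Object object) {
      switch (type.typeId()) {
        case INTEGER:
          // Spark may supply Byte, Short, or Integer for an int column; widen them all.
          return ((Number) object).intValue();
        case BOOLEAN:  // these already match the representation Iceberg expects
        case LONG:
        case FLOAT:
        case DOUBLE:
          return object;
        default:
          throw new UnsupportedOperationException("Not a supported type: " + type);
      }
    }

LONG is intentionally left alone: a BIGINT column registers against Types.LongType, so its values already arrive as the Long the long bucket transform expects.
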
diff --git a/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
index e83709a3b2a9..c275daee5f7e 100644
--- a/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
+++ b/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
@@ -19,13 +19,22 @@
 
 package org.apache.iceberg.spark.source;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.sql.Date;
+import java.sql.Timestamp;
 import java.util.List;
 import org.apache.iceberg.spark.IcebergSpark;
 import org.apache.iceberg.transforms.Transforms;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.types.CharType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.VarcharType;
+import org.assertj.core.api.Assertions;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -48,23 +57,132 @@ public static void stopSpark() {
   }
 
   @Test
-  public void testRegisterBucketUDF() {
+  public void testRegisterIntegerBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
     List<Row> results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList();
     Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
         results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterShortBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterByteBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterLongBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
-    List<Row> results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
-    Assert.assertEquals(1, results2.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L),
-        results2.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterStringBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16);
-    List<Row> results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
-    Assert.assertEquals(1, results3.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterVarCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
-        results3.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDateBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16)
+        .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterTimestampBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16)
+        .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBinaryBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16)
+        .apply(ByteBuffer.wrap(new byte[]{0x00, 0x20, 0x00, 0x1F})),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDecimalBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16)
+        .apply(new BigDecimal("11.11")),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBooleanBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: boolean");
+  }
+
+  @Test
+  public void testRegisterDoubleBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: double");
+  }
+
+  @Test
+  public void testRegisterFloatBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: float");
   }
 }
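
The tests now cover one UDF per supported source type — int, smallint, tinyint, bigint, string, char, varchar, date, timestamp, binary, and decimal — plus the three types Transforms.bucket rejects outright (boolean, double, float). An end-to-end usage sketch follows; the local session setup and the SELECT are illustrative only, while the registration call mirrors the tests above:

    import org.apache.iceberg.spark.IcebergSpark;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    public class RegisterBucketUdfExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[2]")  // assumption: a local session just for the demo
            .getOrCreate();

        // After this patch, a SMALLINT argument is widened to int before hashing,
        // so 1S lands in the same bucket as the INT literal 1.
        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
        spark.sql("SELECT iceberg_bucket_short_16(1S)").show();

        spark.stop();
      }
    }
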