IcebergSpark.java
@@ -34,6 +34,7 @@ public static void registerBucketUDF(SparkSession session, String funcName, DataType sourceType, int numBuckets) {
     SparkTypeToType typeConverter = new SparkTypeToType();
     Type sourceIcebergType = typeConverter.atomic(sourceType);
     Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, numBuckets);
-    session.udf().register(funcName, bucket::apply, DataTypes.IntegerType);
+    session.udf().register(funcName,
+        value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType);
   }
 }
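The old registration passed Spark's external Java values (java.sql.Date, java.sql.Timestamp, BigDecimal, and so on) straight into the Iceberg transform, which expects Iceberg's internal representations (days since epoch as int, microseconds as long). Routing each value through SparkValueConverter.convert first closes that gap. A minimal usage sketch, assuming a live SparkSession named spark:

    // Before this change, the DATE call below failed because the bucket
    // transform received a java.sql.Date instead of the int day ordinal
    // it expects; with the converter in place it returns a bucket number.
    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
    spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").show();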
SparkValueConverter.java
@@ -79,8 +79,9 @@ public static Object convert(Type type, Object object) {
         return DateTimeUtils.fromJavaTimestamp((Timestamp) object);
       case BINARY:
         return ByteBuffer.wrap((byte[]) object);
-      case BOOLEAN:
       case INTEGER:
         return ((Number) object).intValue();
+      case BOOLEAN:
       case LONG:
       case FLOAT:
       case DOUBLE:
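The BOOLEAN label moves because java.lang.Boolean does not extend java.lang.Number: grouped with INTEGER, a boolean input hit ((Number) object).intValue() and threw ClassCastException. In the new grouping it falls through to the branch that returns the object unchanged. A tiny self-contained illustration (a demo class, not part of the PR):

    public class BooleanIsNotANumber {
      public static void main(String[] args) {
        Object value = Boolean.TRUE;
        System.out.println(value instanceof Number);  // prints: false
        // ((Number) value).intValue();               // would throw ClassCastException
      }
    }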
TestIcebergSpark.java
@@ -19,13 +19,22 @@
 
 package org.apache.iceberg.spark.source;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.sql.Date;
+import java.sql.Timestamp;
 import java.util.List;
 import org.apache.iceberg.spark.IcebergSpark;
 import org.apache.iceberg.transforms.Transforms;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.types.CharType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.VarcharType;
+import org.assertj.core.api.Assertions;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -48,23 +57,132 @@ public static void stopSpark() {
   }
 
   @Test
-  public void testRegisterBucketUDF() {
+  public void testRegisterIntegerBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
     List<Row> results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList();
     Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
         results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterShortBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterByteBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterLongBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
-    List<Row> results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
-    Assert.assertEquals(1, results2.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L),
-        results2.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterStringBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16);
-    List<Row> results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
-    Assert.assertEquals(1, results3.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterVarCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
-        results3.get(0).getInt(0));
+        results.get(0).getInt(0));
   }
+
+  @Test
+  public void testRegisterDateBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16)
+            .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterTimestampBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16)
+            .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBinaryBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16)
+            .apply(ByteBuffer.wrap(new byte[]{0x00, 0x20, 0x00, 0x1F})),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDecimalBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16)
+            .apply(new BigDecimal("11.11")),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBooleanBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: boolean");
+  }
+
+  @Test
+  public void testRegisterDoubleBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: double");
+  }
+
+  @Test
+  public void testRegisterFloatBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: float");
+  }
 }
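The split tests pin down the UDF's result for every Spark type the converter handles, and the error message for the types bucketing rejects (boolean, double, float). As a sketch of what this enables (the table and column names here are made up, not from the PR): for a table partitioned by bucket(16, id), the registered UDF lets you compute a row's bucket directly from SQL.

    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
    spark.sql("SELECT id, iceberg_bucket_long_16(id) AS bucket FROM db.events").show();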