@@ -0,0 +1,310 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.functions;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Set;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.util.BucketUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;
import org.apache.spark.unsafe.types.UTF8String;

/**
* A Spark function implementation for the Iceberg bucket transform.
*
* <p>Example usage: {@code SELECT system.bucket(128, 'abc')}, which returns the bucket 122.
*
* <p>Note that for performance reasons, the given input number of buckets is not validated in the
* implementations used in code-gen. The number of buckets must be positive to give meaningful
* results.
*/
public class BucketFunction implements UnboundFunction {
private static final int NUM_BUCKETS_ORDINAL = 0;
private static final int VALUE_ORDINAL = 1;
private static final Set<DataType> SUPPORTED_NUM_BUCKETS_TYPES =
ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType);

@Override
public BoundFunction bind(StructType inputType) {
if (inputType.size() != 2) {
throw new UnsupportedOperationException(
"Wrong number of inputs (expected numBuckets and value)");
}

StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL];
StructField valueField = inputType.fields()[VALUE_ORDINAL];

if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) {
throw new UnsupportedOperationException(
"Expected number of buckets to be tinyint, shortint or int");
}

DataType type = valueField.dataType();
if (type instanceof DateType) {
return new BucketInt(type);
} else if (type instanceof ByteType
|| type instanceof ShortType
|| type instanceof IntegerType) {
return new BucketInt(DataTypes.IntegerType);
} else if (type instanceof LongType) {
return new BucketLong(type);
} else if (type instanceof TimestampType) {
return new BucketLong(type);
} else if (type instanceof DecimalType) {
return new BucketDecimal(type);
} else if (type instanceof StringType) {
return new BucketString();
} else if (type instanceof BinaryType) {
return new BucketBinary();
} else {
throw new UnsupportedOperationException(
"Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary");
}
}

@Override
public String description() {
return name()
+ "(numBuckets, col) - Call Iceberg's bucket transform\n"
+ " numBuckets :: number of buckets to divide the rows into, e.g. bucket(100, 34) -> 79 (must be a tinyint, smallint, or int)\n"
+ " col :: column to bucket (must be a date, integer, long, timestamp, decimal, string, or binary)";
}

@Override
public String name() {
return "bucket";
}

public abstract static class BucketBase implements ScalarFunction<Integer> {
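// Clear the sign bit so the dividend is non-negative, then take the remainder,
// giving a bucket id in [0, numBuckets). The hash itself is Iceberg's 32-bit
// Murmur3 (see BucketUtil), per the bucket transform spec.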
public static int apply(int numBuckets, int hashedValue) {
return (hashedValue & Integer.MAX_VALUE) % numBuckets;
}

@Override
public String name() {
return "bucket";
}

@Override
public DataType resultType() {
return DataTypes.IntegerType;
}
}

// Used for both int and date - tinyint and smallint are upcast to int by Spark.
public static class BucketInt extends BucketBase {
private final DataType sqlType;

// magic method used in codegen
public static int invoke(int numBuckets, int value) {
return apply(numBuckets, hash(value));
}

// Visible for testing
public static int hash(int value) {
return BucketUtil.hash(value);
}

public BucketInt(DataType sqlType) {
this.sqlType = sqlType;
}

@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, sqlType};
}

@Override
public String canonicalName() {
return String.format("iceberg.bucket(%s)", sqlType.catalogString());
}

@Override
public Integer produceResult(InternalRow input) {
// return null for null input to match what Spark does in the code-generated versions.
return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
? null
: invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getInt(VALUE_ORDINAL));
}
}

// Used for both bigint and timestamp
public static class BucketLong extends BucketBase {
private final DataType sqlType;

// magic function for usage with codegen - needs to be static
public static int invoke(int numBuckets, long value) {
return apply(numBuckets, hash(value));
}

// Visible for testing
public static int hash(long value) {
return BucketUtil.hash(value);
}

public BucketLong(DataType sqlType) {
this.sqlType = sqlType;
}

@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, sqlType};
}

@Override
public String canonicalName() {
return String.format("iceberg.bucket(%s)", sqlType.catalogString());
}

@Override
public Integer produceResult(InternalRow input) {
return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
? null
: invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getLong(VALUE_ORDINAL));
}
}

public static class BucketString extends BucketBase {
// magic function for usage with codegen
public static Integer invoke(int numBuckets, UTF8String value) {
if (value == null) {
return null;
}

// TODO - We can probably hash the bytes directly given they're already UTF-8 input.
return apply(numBuckets, hash(value.toString()));
}

// Visible for testing
public static int hash(String value) {
return BucketUtil.hash(value);
}

@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.StringType};
}

@Override
public String canonicalName() {
return "iceberg.bucket(string)";
}

@Override
public Integer produceResult(InternalRow input) {
return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
? null
: invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getUTF8String(VALUE_ORDINAL));
}
}

public static class BucketBinary extends BucketBase {
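// magic method used in codegen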
public static Integer invoke(int numBuckets, byte[] value) {
if (value == null) {
return null;
}

return apply(numBuckets, hash(ByteBuffer.wrap(value)));
}

// Visible for testing
public static int hash(ByteBuffer value) {
return BucketUtil.hash(value);
}

@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType};
}

@Override
public Integer produceResult(InternalRow input) {
return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
? null
: invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getBinary(VALUE_ORDINAL));
}

@Override
public String canonicalName() {
return "iceberg.bucket(binary)";
}
}

public static class BucketDecimal extends BucketBase {
private final DataType sqlType;
private final int precision;
private final int scale;

// magic method used in codegen
public static Integer invoke(int numBuckets, Decimal value) {
if (value == null) {
return null;
}

return apply(numBuckets, hash(value.toJavaBigDecimal()));
}

// Visible for testing
public static int hash(BigDecimal value) {
return BucketUtil.hash(value);
}

public BucketDecimal(DataType sqlType) {
this.sqlType = sqlType;
this.precision = ((DecimalType) sqlType).precision();
this.scale = ((DecimalType) sqlType).scale();
}

@Override
public DataType[] inputTypes() {
return new DataType[] {DataTypes.IntegerType, sqlType};
}

@Override
public Integer produceResult(InternalRow input) {
return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
? null
: invoke(
input.getInt(NUM_BUCKETS_ORDINAL), input.getDecimal(VALUE_ORDINAL, precision, scale));
}

@Override
public String canonicalName() {
return "iceberg.bucket(decimal)";
}
}
}
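As a quick sanity check, the codegen-facing static entry points can be called directly. A minimal sketch follows (the demo class name is hypothetical; it assumes the Iceberg Spark runtime jar is on the classpath):

import org.apache.iceberg.spark.functions.BucketFunction;
import org.apache.spark.unsafe.types.UTF8String;

public class BucketFunctionDemo { // hypothetical demo class
  public static void main(String[] args) {
    // Matches the Javadoc example: SELECT system.bucket(128, 'abc') -> 122
    System.out.println(BucketFunction.BucketString.invoke(128, UTF8String.fromString("abc")));

    // Matches the description() example: bucket(100, 34) -> 79
    System.out.println(BucketFunction.BucketInt.invoke(100, 34));
  }
}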
@@ -32,6 +32,7 @@ private SparkFunctions() {}
private static final Map<String, UnboundFunction> FUNCTIONS =
ImmutableMap.of(
"iceberg_version", new IcebergVersionFunction(),
"bucket", new BucketFunction(),
"truncate", new TruncateFunction());

private static final List<String> FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet());
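And a minimal sketch of how Spark's analyzer binds the function once it is registered above (class name hypothetical; the StructType mirrors the (numBuckets, value) argument list that bind expects):

import org.apache.iceberg.spark.functions.BucketFunction;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class BucketBindSketch { // hypothetical demo class
  public static void main(String[] args) {
    // Spark passes bind() a struct of the call's argument types, e.g. for bucket(16, stringCol).
    StructType inputType = new StructType(new StructField[] {
      new StructField("numBuckets", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("value", DataTypes.StringType, true, Metadata.empty())
    });
    BoundFunction bound = new BucketFunction().bind(inputType);
    System.out.println(bound.canonicalName()); // iceberg.bucket(string)
  }
}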