diff --git a/api/src/main/java/org/apache/iceberg/transforms/Truncate.java b/api/src/main/java/org/apache/iceberg/transforms/Truncate.java index f0ac2c63d033..d0e301c5c63f 100644 --- a/api/src/main/java/org/apache/iceberg/transforms/Truncate.java +++ b/api/src/main/java/org/apache/iceberg/transforms/Truncate.java @@ -31,6 +31,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Objects; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.types.Type; +import org.apache.iceberg.util.TruncateUtil; import org.apache.iceberg.util.UnicodeUtil; abstract class Truncate implements Transform { @@ -87,7 +88,7 @@ public Integer apply(Integer value) { return null; } - return value - (((value % width) + width) % width); + return TruncateUtil.truncateInt(width, value); } @Override @@ -171,7 +172,7 @@ public Long apply(Long value) { return null; } - return value - (((value % width) + width) % width); + return TruncateUtil.truncateLong(width, value); } @Override @@ -391,9 +392,7 @@ public ByteBuffer apply(ByteBuffer value) { return null; } - ByteBuffer ret = value.duplicate(); - ret.limit(Math.min(value.limit(), value.position() + length)); - return ret; + return TruncateUtil.truncateByteBuffer(length, value); } @Override @@ -480,16 +479,7 @@ public BigDecimal apply(BigDecimal value) { return null; } - BigDecimal remainder = - new BigDecimal( - value - .unscaledValue() - .remainder(unscaledWidth) - .add(unscaledWidth) - .remainder(unscaledWidth), - value.scale()); - - return value.subtract(remainder); + return TruncateUtil.truncateDecimal(unscaledWidth, value); } @Override diff --git a/api/src/main/java/org/apache/iceberg/util/TruncateUtil.java b/api/src/main/java/org/apache/iceberg/util/TruncateUtil.java new file mode 100644 index 000000000000..7be5f02fc01a --- /dev/null +++ b/api/src/main/java/org/apache/iceberg/util/TruncateUtil.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.util; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; + +/** + * Contains the logic for various {@code truncate} transformations for various types. + * + *

This utility class allows for the logic to be reused in different scenarios where input + * validation is done at different times either in org.apache.iceberg.transforms.Truncate or within + * defined SQL functions for different compute engines for usage in SQL. + * + *

In general, the inputs to the functions should have already been validated by the calling + * code, as different classes use truncate with different preprocessing. This generally means that + * the truncation width is positive and the value to truncate is non-null. + * + *

See also {@linkplain UnicodeUtil#truncateString(CharSequence, int)} and {@link + * BinaryUtil#truncateBinary(ByteBuffer, int)} + */ +public class TruncateUtil { + + private TruncateUtil() {} + + public static ByteBuffer truncateByteBuffer(int width, ByteBuffer value) { + ByteBuffer ret = value.duplicate(); + ret.limit(Math.min(value.limit(), value.position() + width)); + return ret; + } + + public static byte truncateByte(int width, byte value) { + return (byte) (value - (((value % width) + width) % width)); + } + + public static short truncateShort(int width, short value) { + return (short) (value - (((value % width) + width) % width)); + } + + public static int truncateInt(int width, int value) { + return value - (((value % width) + width) % width); + } + + public static long truncateLong(int width, long value) { + return value - (((value % width) + width) % width); + } + + public static BigDecimal truncateDecimal(BigInteger unscaledWidth, BigDecimal value) { + BigDecimal remainder = + new BigDecimal( + value + .unscaledValue() + .remainder(unscaledWidth) + .add(unscaledWidth) + .remainder(unscaledWidth), + value.scale()); + + return value.subtract(remainder); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java index f215aa033c5a..8f2abae89970 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/BaseCatalog.java @@ -18,18 +18,27 @@ */ package org.apache.iceberg.spark; +import org.apache.iceberg.spark.functions.SparkFunctions; import org.apache.iceberg.spark.procedures.SparkProcedures; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.iceberg.spark.source.HasIcebergCatalog; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import 
org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.StagingTableCatalog; import org.apache.spark.sql.connector.catalog.SupportsNamespaces; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; import org.apache.spark.sql.connector.iceberg.catalog.Procedure; import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog; abstract class BaseCatalog - implements StagingTableCatalog, ProcedureCatalog, SupportsNamespaces, HasIcebergCatalog { + implements StagingTableCatalog, + ProcedureCatalog, + SupportsNamespaces, + HasIcebergCatalog, + FunctionCatalog { @Override public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException { @@ -47,4 +56,37 @@ public Procedure loadProcedure(Identifier ident) throws NoSuchProcedureException throw new NoSuchProcedureException(ident); } + + @Override + public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException { + if (namespace.length == 0 + || (namespace.length == 1 && namespace[0].equalsIgnoreCase("system"))) { + return SparkFunctions.list().stream() + .map(name -> Identifier.of(namespace, name)) + .toArray(Identifier[]::new); + } else if (namespaceExists(namespace)) { + return new Identifier[0]; + } + + throw new NoSuchNamespaceException(namespace); + } + + @Override + public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { + String[] namespace = ident.namespace(); + String name = ident.name(); + + // Allow for empty namespace as Spark's storage partitioned joins look up + // the corresponding functions to generate transforms for partitioning + // with an empty namespace, such as `bucket`. 
+ if (namespace.length == 0 + || (namespace.length == 1 && namespace[0].equalsIgnoreCase("system"))) { + UnboundFunction func = SparkFunctions.load(name); + if (func != null) { + return func; + } + } + + throw new NoSuchFunctionException(ident); + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java index ebf12cb2c22e..444244e38b42 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSessionCatalog.java @@ -27,7 +27,6 @@ import org.apache.iceberg.spark.source.HasIcebergCatalog; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; -import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; @@ -42,7 +41,6 @@ import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.connector.catalog.TableChange; -import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -379,14 +377,4 @@ public Catalog icebergCatalog() { "Cannot return underlying Iceberg Catalog, wrapped catalog does not contain an Iceberg Catalog"); return ((HasIcebergCatalog) icebergCatalog).icebergCatalog(); } - - @Override - public Identifier[] listFunctions(String[] namespace) { - return new Identifier[0]; - } - - @Override - public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException { - throw new 
NoSuchFunctionException(ident); - } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java new file mode 100644 index 000000000000..90cb00e301e4 --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.functions; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; + +public class SparkFunctions { + + private SparkFunctions() {} + + private static final Map FUNCTIONS = + ImmutableMap.of("truncate", new TruncateFunction()); + + private static final List FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet()); + + // Functions that are added to all Iceberg catalogs should be accessed with either the `system` + // namespace or no namespace at all, so a list of names alone is returned. 
+ public static List list() { + return FUNCTION_NAMES; + } + + public static UnboundFunction load(String name) { + // function resolution is case insensitive to match the existing Spark behavior for functions + UnboundFunction func = FUNCTIONS.get(name.toLowerCase(Locale.ROOT)); + return func; + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java new file mode 100644 index 000000000000..b362d844770e --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.functions; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.TruncateUtil; +import org.apache.iceberg.util.UnicodeUtil; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VarcharType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Implementation of {@link UnboundFunction} that matches the truncate transformation. This + * unbound function is registered with the {@link org.apache.iceberg.spark.SparkCatalog} such that + * the function can be used as {@code truncate(width, col)} or {@code truncate(2, col)}. + * + *

Specific {@link BoundFunction} implementations are resolved based on their input types. As + * with transforms, the truncation width must be positive. + * + *

For efficiency in generated code, the {@code width} is not validated. It is the + * responsibility of calling code of these functions to not call truncate with a non-positive + * width. + */ +public class TruncateFunction implements UnboundFunction { + private static final List truncateableAtomicTypes = + ImmutableList.of( + DataTypes.ByteType, + DataTypes.ShortType, + DataTypes.IntegerType, + DataTypes.LongType, + DataTypes.StringType, + DataTypes.BinaryType); + + private static void validateTruncationFieldType(DataType dt) { + if (truncateableAtomicTypes.stream().noneMatch(type -> type.sameType(dt)) + && !(dt instanceof DecimalType)) { + String expectedTypes = + "[ByteType, ShortType, IntegerType, LongType, StringType, BinaryType, DecimalType]"; + throw new UnsupportedOperationException( + String.format( + "Invalid input type to truncate. Expected one of %s, but found %s", + expectedTypes, dt)); + } + } + + private static void validateTruncationWidthType(DataType widthType) { + if (!DataTypes.IntegerType.sameType(widthType) + && !DataTypes.ShortType.sameType(widthType) + && !DataTypes.ByteType.sameType(widthType)) { + throw new UnsupportedOperationException( + "Expected truncation width to be one of [ByteType, ShortType, IntegerType], but found " + + widthType); + } + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 2) { + throw new UnsupportedOperationException( + String.format( + "Invalid input type. 
Expected 2 fields but found %s", inputType.fields().length)); + } + + StructField widthField = inputType.apply(0); + StructField toTruncateField = inputType.apply(1); + + validateTruncationFieldType(toTruncateField.dataType()); + validateTruncationWidthType(widthField.dataType()); + + DataType toTruncateDataType = toTruncateField.dataType(); + if (toTruncateDataType instanceof ByteType) { + return new TruncateTinyInt(); + } else if (toTruncateDataType instanceof ShortType) { + return new TruncateSmallInt(); + } else if (toTruncateDataType instanceof IntegerType) { + return new TruncateInt(); + } else if (toTruncateDataType instanceof LongType) { + return new TruncateBigInt(); + } else if (toTruncateDataType instanceof DecimalType) { + return new TruncateDecimal( + ((DecimalType) toTruncateDataType).precision(), + ((DecimalType) toTruncateDataType).scale()); + } else if (toTruncateDataType instanceof StringType + || toTruncateDataType instanceof VarcharType + || toTruncateDataType instanceof CharType) { + return new TruncateString(); + } else if (toTruncateDataType instanceof BinaryType) { + return new TruncateBinary(); + } else { + throw new UnsupportedOperationException("Cannot truncate type: " + toTruncateDataType); + } + } + + @Override + public String description() { + return "Truncate - The Iceberg truncate function used for truncate partition transformations.\n" + + "\tCalled with the truncation width as the first argument: e.g. 
system.truncate(width, col)"; + } + + @Override + public String name() { + return "truncate"; + } + + public abstract static class TruncateBase implements ScalarFunction { + @Override + public String name() { + return "truncate"; + } + } + + public static class TruncateTinyInt extends TruncateBase { + public static byte invoke(int width, byte value) { + return TruncateUtil.truncateByte(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ByteType}; + } + + @Override + public DataType resultType() { + return DataTypes.ByteType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](tinyint)"; + } + + @Override + public Byte produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + Byte toTruncate = !input.isNullAt(1) ? input.getByte(1) : null; + return toTruncate != null ? invoke(width, toTruncate) : null; + } + } + + public static class TruncateSmallInt extends TruncateBase { + // magic method used in codegen + public static short invoke(int width, short value) { + return TruncateUtil.truncateShort(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.ShortType}; + } + + @Override + public DataType resultType() { + return DataTypes.ShortType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](smallint)"; + } + + @Override + public Short produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + Short toTruncate = !input.isNullAt(1) ? input.getShort(1) : null; + return toTruncate != null ? 
invoke(width, toTruncate) : null; + } + } + + public static class TruncateInt extends TruncateBase { + // magic method used in codegen + public static int invoke(int width, int value) { + return TruncateUtil.truncateInt(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.IntegerType}; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](int)"; + } + + @Override + public Integer produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + Integer toTruncate = !input.isNullAt(1) ? input.getInt(1) : null; + return toTruncate != null ? invoke(width, toTruncate) : null; + } + } + + public static class TruncateBigInt extends TruncateBase { + // magic function for usage with codegen + public static long invoke(int width, long value) { + return TruncateUtil.truncateLong(width, value); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.LongType}; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](bigint)"; + } + + @Override + public Long produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + Long toTruncate = !input.isNullAt(1) ? input.getLong(1) : null; + return toTruncate != null ? invoke(width, toTruncate) : null; + } + } + + public static class TruncateString extends TruncateBase { + // magic function for usage with codegen + // todo - this can be made more efficient but first keep the implementation the same. 
+ public static UTF8String invoke(int width, UTF8String value) { + if (value == null) { + return null; + } + + ByteBuffer bb = value.getByteBuffer(); + CharSequence charSequence = StandardCharsets.UTF_8.decode(bb); + CharSequence truncated = UnicodeUtil.truncateString(charSequence, width); + ByteBuffer truncatedBytes = StandardCharsets.UTF_8.encode(CharBuffer.wrap(truncated)); + return UTF8String.fromBytes(ByteBuffers.toByteArray(truncatedBytes)); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.StringType}; + } + + @Override + public DataType resultType() { + return DataTypes.StringType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](string)"; + } + + @Override + public String produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + UTF8String toTruncate = !input.isNullAt(1) ? input.getUTF8String(1) : null; + UTF8String result = toTruncate != null ? invoke(width, toTruncate) : null; + return result != null ? result.toString() : null; + } + } + + public static class TruncateBinary extends TruncateBase { + // magic method used in codegen + public static byte[] invoke(int width, byte[] value) { + if (value == null) { + return null; + } + + return ByteBuffers.toByteArray( + TruncateUtil.truncateByteBuffer(width, ByteBuffer.wrap(value))); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType}; + } + + @Override + public DataType resultType() { + return DataTypes.BinaryType; + } + + @Override + public String canonicalName() { + return "org.apache.iceberg.spark.functions.truncate[width](binary)"; + } + + @Override + public byte[] produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + byte[] toTruncate = !input.isNullAt(1) ? input.getBinary(1) : null; + return toTruncate != null ? 
invoke(width, toTruncate) : null; + } + } + + public static class TruncateDecimal extends TruncateBase { + private final int precision; + private final int scale; + + public TruncateDecimal(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + // magic method used in codegen + public static Decimal invoke(int width, Decimal value) { + if (value == null) { + return null; + } + + return Decimal.apply( + TruncateUtil.truncateDecimal(BigInteger.valueOf(width), value.toJavaBigDecimal())); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] {DataTypes.IntegerType, DataTypes.createDecimalType(precision, scale)}; + } + + @Override + public DataType resultType() { + return DataTypes.createDecimalType(precision, scale); + } + + @Override + public String canonicalName() { + return String.format( + "org.apache.iceberg.spark.functions.truncate[width](decimal(%d,%d))", precision, scale); + } + + @Override + public Decimal produceResult(InternalRow input) { + Integer width = readAndValidateWidth(input); + + Decimal toTruncate = !input.isNullAt(1) ? input.getDecimal(1, precision, scale) : null; + return toTruncate != null ? invoke(width, toTruncate) : null; + } + } + + private static Integer readAndValidateWidth(InternalRow input) { + Integer width = !input.isNullAt(0) ? 
input.getInt(0) : null; + if (width == null) { + throw new IllegalArgumentException("Invalid truncation width: null"); + } + + if (width <= 0) { + throw new IllegalArgumentException( + String.format("Invalid truncate width: %s (must be > 0)", width)); + } + + return width; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java new file mode 100644 index 000000000000..31a8a8494694 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/TestFunctionCatalog.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.functions.SparkFunctions; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.connector.catalog.FunctionCatalog; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import scala.collection.JavaConverters; + +@RunWith(Parameterized.class) +public class TestFunctionCatalog extends SparkCatalogTestBase { + @Parameterized.Parameters(name = "catalogConfig = {0}") + public static Object[][] parameters() { + return new Object[][] { + {SparkCatalogConfig.HADOOP}, {SparkCatalogConfig.HIVE}, {SparkCatalogConfig.SPARK} + }; + } + + private static final Namespace NS = Namespace.of("db"); + private final boolean isSessionCatalog; + private final String fullNamespace; + private final FunctionCatalog asFunctionCatalog; + + public TestFunctionCatalog(SparkCatalogConfig catalogConfig) { + super(catalogConfig); + this.isSessionCatalog = "spark_catalog".equals(catalogName); + this.fullNamespace = (isSessionCatalog ? 
"" : catalogName + ".") + NS; + this.asFunctionCatalog = castToFunctionCatalog(catalogName); + } + + @Before + public void createNamespace() { + sql("CREATE NAMESPACE IF NOT EXISTS %s", fullNamespace); + } + + @After + public void cleanNamespaces() { + sql("DROP NAMESPACE IF EXISTS %s", fullNamespace); + } + + @Test + public void testLoadListAndUseFunctionsFromSystemNamespace() + throws NoSuchFunctionException, NoSuchNamespaceException { + // TODO - Remove this assumption when the SparkSessionCatalog can resolve functions in the + // `system` namespace. + Assume.assumeFalse( + "The session catalog cannot use functions via the `system` namespace", isSessionCatalog); + String[] namespace = {"system"}; + String name = "truncate"; + Identifier identifier = Identifier.of(namespace, name); + + assertListingLoadingAndBindingFrom(identifier); + } + + @Test + public void testLoadListAndUseFunctionsFromEmptyNamespace() + throws NoSuchFunctionException, NoSuchNamespaceException { + String[] namespace = {}; + String name = "truncate"; + Identifier identifier = Identifier.of(namespace, name); + + assertListingLoadingAndBindingFrom(identifier); + } + + @Test + public void testCannotLoadFunctionsFromInvalidNamespace() { + AssertHelpers.assertThrows( + "Function Catalog functions should only be accessible from the system namespace and empty namespace", + AnalysisException.class, + "Undefined function", + () -> sql("SELECT %s.truncate(1, 2)", fullNamespace)); + } + + @Test + public void testCannotUseUndefinedFunction() { + AssertHelpers.assertThrows( + "Using an undefined function should throw", + AnalysisException.class, + "Undefined function", + () -> sql("SELECT undefined_function(1, 2)")); + } + + private void assertListingLoadingAndBindingFrom(Identifier identifier) + throws NoSuchNamespaceException, NoSuchFunctionException { + String[] namespace = identifier.namespace(); + + Assert.assertTrue( + "The function catalog only allows using the namespace `system` or an empty 
namespace", + namespace.length == 0 + || (namespace.length == 1 && namespace[0].equalsIgnoreCase("system"))); + + // Load + UnboundFunction unboundFunction = asFunctionCatalog.loadFunction(identifier); + Assert.assertNotNull( + identifier + " function should be loadable via the FunctionCatalog", unboundFunction); + + // List + Identifier[] identifiers = asFunctionCatalog.listFunctions(namespace); + Assert.assertTrue( + String.format( + "Functions listed from the %s namespace should not be empty", + Arrays.toString(namespace)), + identifiers.length > 0); + List functionNames = + Arrays.stream(identifiers).map(Identifier::name).collect(Collectors.toList()); + Assertions.assertThat(functionNames).hasSameElementsAs(SparkFunctions.list()); + + // Bind - assumes truncate function is used + ScalarFunction boundTruncate = + (ScalarFunction) + unboundFunction.bind( + new StructType() + .add("width", DataTypes.IntegerType) + .add("value", DataTypes.IntegerType)); + + Object width = Integer.valueOf(10); + Object toTruncate = Integer.valueOf(9); + Assert.assertEquals( + String.format( + "Binding the %s function from the function catalog should produce a usable function", + identifier), + Integer.valueOf(0), + boundTruncate.produceResult( + InternalRow.fromSeq( + JavaConverters.asScalaBufferConverter(ImmutableList.of(width, toTruncate)) + .asScala() + .toSeq()))); + } + + private FunctionCatalog castToFunctionCatalog(String name) { + return (FunctionCatalog) spark.sessionState().catalogManager().catalog(name); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java new file mode 100644 index 000000000000..292a045c27f2 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkTruncateFunction.java @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.iceberg.spark.sql;

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests for the Iceberg {@code truncate} SQL function exposed through Spark's FunctionCatalog.
 *
 * <p>Covers every supported input type (tinyint, smallint, int, bigint, decimal, string, binary),
 * null handling, invalid width types, and verifies that Spark resolves the magic (codegen-friendly)
 * function variants via {@code staticinvoke}.
 */
@RunWith(Parameterized.class)
public class TestSparkTruncateFunction extends SparkTestBaseWithCatalog {

  // TODO - Add tests for SparkCatalogConfig.SPARK once the `system` namespace is resolvable from
  // the session catalog.
  @Parameterized.Parameters(name = "catalogConfig = {0}")
  public static Object[][] parameters() {
    return new Object[][] {{SparkCatalogConfig.HADOOP}, {SparkCatalogConfig.HIVE}};
  }

  private static final Namespace SYSTEM = Namespace.of("system");

  // Fully qualified `system` namespace for the catalog under test; the session catalog
  // must use the bare `system` namespace without a catalog qualifier.
  private final String systemNamespace;
  private final boolean isSessionCatalog;

  public TestSparkTruncateFunction(SparkCatalogConfig catalogConfig) {
    super(catalogConfig);
    this.isSessionCatalog = "spark_catalog".equals(catalogName);
    this.systemNamespace = (isSessionCatalog ? "" : catalogName + ".") + SYSTEM;
  }

  @Test
  public void testTruncateUsingSystemNamespaceForNonSessionCatalogs() {
    // Non-session catalogs use v2 function resolution always
    Assume.assumeFalse(isSessionCatalog);

    Assert.assertEquals(
        "Should be able to use the truncate function with the system namespace, for non-session catalogs",
        5,
        scalarSql("SELECT %s.system.truncate(5, 6)", catalogName));
  }

  @Test
  @Ignore // TODO - Return to this once session catalog is supported.
  public void testTruncateUsingSystemNamespaceSessionCatalogs() {
    // Spark's Session catalog only allows using new functions from a registered namespace.
    Assume.assumeTrue(isSessionCatalog);

    Assert.assertEquals(
        "Should be able to call system.truncate from session catalog, provided that we're in an Iceberg namespace",
        5,
        scalarSql("SELECT %s.truncate(5, 6)", systemNamespace));

    // Note session catalog can only use special `system` namespace if it's not qualified with
    // catalog or db name.
    AssertHelpers.assertThrows(
        "Session catalog cannot be qualified when using system",
        AnalysisException.class,
        "Undefined function",
        () -> scalarSql("SELECT spark_catalog.system.truncate(5, 6)"));

    AssertHelpers.assertThrows(
        "Session catalog only allows usage of system keyword when used on its own",
        AnalysisException.class,
        "Undefined function",
        () -> scalarSql("SELECT system.truncate(6, 5)"));
  }

  @Test
  public void testTruncateTinyInt() {
    Assert.assertEquals((byte) 0, scalarSql("SELECT %s.truncate(10, 0Y)", systemNamespace));
    Assert.assertEquals((byte) 0, scalarSql("SELECT %s.truncate(10, 1Y)", systemNamespace));
    Assert.assertEquals((byte) 0, scalarSql("SELECT %s.truncate(10, 5Y)", systemNamespace));
    Assert.assertEquals((byte) 0, scalarSql("SELECT %s.truncate(10, 9Y)", systemNamespace));
    Assert.assertEquals((byte) 10, scalarSql("SELECT %s.truncate(10, 10Y)", systemNamespace));
    Assert.assertEquals((byte) 10, scalarSql("SELECT %s.truncate(10, 11Y)", systemNamespace));
    // Truncation rounds toward negative infinity, not toward zero
    Assert.assertEquals((byte) -10, scalarSql("SELECT %s.truncate(10, -1Y)", systemNamespace));
    Assert.assertEquals((byte) -10, scalarSql("SELECT %s.truncate(10, -5Y)", systemNamespace));
    Assert.assertEquals((byte) -10, scalarSql("SELECT %s.truncate(10, -10Y)", systemNamespace));
    Assert.assertEquals((byte) -20, scalarSql("SELECT %s.truncate(10, -11Y)", systemNamespace));

    // Check that different widths can be used
    Assert.assertEquals((byte) -2, scalarSql("SELECT %s.truncate(2, -1Y)", systemNamespace));

    // Check that tinyint types are allowed for the width
    Assert.assertEquals((byte) 0, scalarSql("SELECT %s.truncate(5Y, 1Y)", systemNamespace));

    Assert.assertEquals(
        "Null input should return null",
        null,
        scalarSql("SELECT %s.truncate(2, CAST(null AS tinyint))", systemNamespace));
  }

  @Test
  public void testTruncateSmallInt() {
    Assert.assertEquals((short) 0, scalarSql("SELECT %s.truncate(10, 0S)", systemNamespace));
    Assert.assertEquals((short) 0, scalarSql("SELECT %s.truncate(10, 1S)", systemNamespace));
    Assert.assertEquals((short) 0, scalarSql("SELECT %s.truncate(10, 5S)", systemNamespace));
    Assert.assertEquals((short) 0, scalarSql("SELECT %s.truncate(10, 9S)", systemNamespace));
    Assert.assertEquals((short) 10, scalarSql("SELECT %s.truncate(10, 10S)", systemNamespace));
    Assert.assertEquals((short) 10, scalarSql("SELECT %s.truncate(10, 11S)", systemNamespace));
    // Truncation rounds toward negative infinity, not toward zero
    Assert.assertEquals((short) -10, scalarSql("SELECT %s.truncate(10, -1S)", systemNamespace));
    Assert.assertEquals((short) -10, scalarSql("SELECT %s.truncate(10, -5S)", systemNamespace));
    Assert.assertEquals((short) -10, scalarSql("SELECT %s.truncate(10, -10S)", systemNamespace));
    Assert.assertEquals((short) -20, scalarSql("SELECT %s.truncate(10, -11S)", systemNamespace));

    // Check that different widths can be used
    Assert.assertEquals((short) -2, scalarSql("SELECT %s.truncate(2, -1S)", systemNamespace));

    // Check that short types are allowed for the width
    Assert.assertEquals((short) 0, scalarSql("SELECT %s.truncate(5S, 1S)", systemNamespace));

    Assert.assertEquals(
        "Null input should return null",
        null,
        scalarSql("SELECT %s.truncate(2, CAST(null AS smallint))", systemNamespace));
  }

  @Test
  public void testTruncateIntegerLiteralSQL() {
    Assert.assertEquals(0, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 0));
    Assert.assertEquals(0, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 1));
    Assert.assertEquals(0, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 5));
    Assert.assertEquals(0, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 9));
    Assert.assertEquals(10, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 10));
    Assert.assertEquals(10, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, 11));
    // Truncation rounds toward negative infinity, not toward zero
    Assert.assertEquals(-10, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, -1));
    Assert.assertEquals(-10, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, -5));
    Assert.assertEquals(-10, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, -10));
    Assert.assertEquals(-20, scalarSql("SELECT %s.truncate(10, %d)", systemNamespace, -11));

    // Check that different widths can be used
    Assert.assertEquals(-2, scalarSql("SELECT %s.truncate(2, %d)", systemNamespace, -1));
    Assert.assertEquals(0, scalarSql("SELECT %s.truncate(2, %d)", systemNamespace, 1));

    Assert.assertEquals(
        "Null input should return null",
        null,
        scalarSql("SELECT %s.truncate(2, CAST(null AS int))", systemNamespace));
  }

  @Test
  public void testTruncateBigInt() {
    Assert.assertEquals(0L, scalarSql("SELECT %s.truncate(10, 0L)", systemNamespace));
    Assert.assertEquals(0L, scalarSql("SELECT %s.truncate(10, 1L)", systemNamespace));
    Assert.assertEquals(0L, scalarSql("SELECT %s.truncate(10, 5L)", systemNamespace));
    Assert.assertEquals(0L, scalarSql("SELECT %s.truncate(10, 9L)", systemNamespace));
    Assert.assertEquals(10L, scalarSql("SELECT %s.truncate(10, 10L)", systemNamespace));
    Assert.assertEquals(10L, scalarSql("SELECT %s.truncate(10, 11L)", systemNamespace));
    // Truncation rounds toward negative infinity, not toward zero
    Assert.assertEquals(-10L, scalarSql("SELECT %s.truncate(10, -1L)", systemNamespace));
    Assert.assertEquals(-10L, scalarSql("SELECT %s.truncate(10, -5L)", systemNamespace));
    Assert.assertEquals(-10L, scalarSql("SELECT %s.truncate(10, -10L)", systemNamespace));
    Assert.assertEquals(-20L, scalarSql("SELECT %s.truncate(10, -11L)", systemNamespace));

    // Check that different widths can be used
    Assert.assertEquals(-2L, scalarSql("SELECT %s.truncate(2, -1L)", systemNamespace));

    Assert.assertEquals(
        "Null input should return null",
        null,
        scalarSql("SELECT %s.truncate(2, CAST(null AS bigint))", systemNamespace));
  }

  @Test
  public void testTruncateDecimalLiteralSQL() {
    // decimal truncation works by applying the decimal scale to the width: ie 10 scale 2 = 0.10
    Assert.assertEquals(
        new BigDecimal("12.30"),
        scalarSql("SELECT %s.truncate(10, CAST(%f as DECIMAL(9, 2)))", systemNamespace, 12.34));

    Assert.assertEquals(
        new BigDecimal("12.30"),
        scalarSql("SELECT %s.truncate(10, CAST(%f as DECIMAL(9, 2)))", systemNamespace, 12.30));

    Assert.assertEquals(
        new BigDecimal("12.290"),
        scalarSql("SELECT %s.truncate(10, CAST(%f as DECIMAL(9, 3)))", systemNamespace, 12.299));

    Assert.assertEquals(
        new BigDecimal("0.03"),
        scalarSql("SELECT %s.truncate(3, CAST(%f as DECIMAL(5, 2)))", systemNamespace, 0.05));

    Assert.assertEquals(
        new BigDecimal("0.00"),
        scalarSql("SELECT %s.truncate(10, CAST(%f as DECIMAL(9, 2)))", systemNamespace, 0.05));

    // Negative values truncate downward (toward negative infinity)
    Assert.assertEquals(
        new BigDecimal("-0.10"),
        scalarSql("SELECT %s.truncate(10, CAST(%f as DECIMAL(9, 2)))", systemNamespace, -0.05));

    Assert.assertEquals(
        "Implicit decimal scale and precision should be allowed",
        new BigDecimal("12345.3480"),
        scalarSql("SELECT %s.truncate(10, 12345.3482)", systemNamespace));

    Assert.assertEquals(
        "Null input should return null",
        null,
        scalarSql("SELECT %s.truncate(2, CAST(null AS decimal))", systemNamespace));
  }

  @Test
  public void testInvalidTypesForWidthFailFunctionBinding() {
    AssertHelpers.assertThrows(
        "Decimal type should not be coercible to the width field",
        AnalysisException.class,
        "Expected truncation width to be one of [ByteType, ShortType, IntegerType]",
        () -> scalarSql("SELECT %s.truncate(CAST(12.34 as DECIMAL(9, 2)), 10)", systemNamespace));

    AssertHelpers.assertThrows(
        "String type should not be coercible to the width field",
        AnalysisException.class,
        "Expected truncation width to be one of [ByteType, ShortType, IntegerType]",
        () -> scalarSql("SELECT %s.truncate('5', 10)", systemNamespace));

    AssertHelpers.assertThrows(
        "Interval year to month type should not be coercible to the width field",
        AnalysisException.class,
        "Expected truncation width to be one of [ByteType, ShortType, IntegerType]",
        () ->
            scalarSql("SELECT %s.truncate(INTERVAL '100-00' YEAR TO MONTH, 10)", systemNamespace));

    AssertHelpers.assertThrows(
        "Interval day-time type should not be coercible to the width field",
        AnalysisException.class,
        "Expected truncation width to be one of [ByteType, ShortType, IntegerType]",
        () ->
            scalarSql(
                "SELECT %s.truncate(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)",
                systemNamespace));
  }

  @Test
  public void testTruncateString() {
    Assert.assertEquals(
        "Should system.truncate strings longer than length",
        "abcde",
        scalarSql("SELECT %s.truncate(5, 'abcdefg')", systemNamespace));
    Assert.assertEquals(
        "Should not pad strings shorter than length",
        "abc",
        scalarSql("SELECT %s.truncate(5, 'abc')", systemNamespace));
    Assert.assertEquals(
        "Should not alter strings equal to length",
        "abcde",
        scalarSql("SELECT %s.truncate(5, 'abcde')", systemNamespace));
    // String truncation counts code points, not bytes
    Assert.assertEquals(
        "Should handle three-byte UTF-8 characters appropriately",
        "测",
        scalarSql("SELECT %s.truncate(1, '测试')", systemNamespace));
    Assert.assertEquals(
        "Should handle three-byte UTF-8 characters mixed with two byte utf-8 characters",
        "测试ra",
        scalarSql("SELECT %s.truncate(4, '测试raul试测')", systemNamespace));

    Assert.assertEquals(
        "Null input should return null as output",
        null,
        scalarSql("SELECT %s.truncate(3, CAST(null AS string))", systemNamespace));

    Assert.assertEquals(
        "Varchar should work like string",
        "测试ra",
        scalarSql("SELECT %s.truncate(4, CAST('测试raul试测' AS varchar(8)))", systemNamespace));

    Assert.assertEquals(
        "Char should work like string",
        "测试ra",
        scalarSql("SELECT %s.truncate(4, CAST('测试raul试测' AS char(8)))", systemNamespace));
  }

  @Test
  public void testTruncateBinary() {
    Assert.assertArrayEquals(
        new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
        (byte[])
            scalarSql(
                "SELECT %s.truncate(10, X'0102030405060708090a0b0c0d0e0f')", systemNamespace));
    Assert.assertArrayEquals(
        "Should truncate binary values longer than the truncation width",
        "abc".getBytes(StandardCharsets.UTF_8),
        (byte[])
            scalarSql("SELECT %s.truncate(3, %s)", systemNamespace, asBytesLiteral("abcdefg")));
    Assert.assertArrayEquals(
        "Should not truncate, pad, or trim the input when its length is less than the width",
        "abc\0\0".getBytes(StandardCharsets.UTF_8),
        (byte[])
            scalarSql("SELECT %s.truncate(10, %s)", systemNamespace, asBytesLiteral("abc\0\0")));
    Assert.assertArrayEquals(
        "Should not pad the input when its length is equal to the width",
        "abc".getBytes(StandardCharsets.UTF_8),
        (byte[]) scalarSql("SELECT %s.truncate(3, %s)", systemNamespace, asBytesLiteral("abc")));
    // Binary truncation counts bytes, so width 6 keeps exactly the two three-byte characters
    Assert.assertArrayEquals(
        "Should handle three-byte UTF-8 characters appropriately",
        "测试".getBytes(StandardCharsets.UTF_8),
        (byte[]) scalarSql("SELECT %s.truncate(6, %s)", systemNamespace, asBytesLiteral("测试_")));

    Assert.assertEquals(
        "Null input should return null as output",
        null,
        scalarSql("SELECT %s.truncate(3, CAST(null AS binary))", systemNamespace));
  }

  @Test
  public void testTruncateUsingDataframeForWidthWithVaryingWidth() {
    // This situation is atypical but allowed. Typically width is static.
    // truncate(value + 1, value) is always 0 for non-negative values, so every row should
    // survive the `truncated_value == 0` filter.
    long numRows = 10L;
    long numZeroRows =
        spark
            .range(numRows)
            .toDF("value")
            .selectExpr("CAST(value +1 AS INT) AS width", "value")
            .selectExpr(
                String.format("%s.truncate(width, value) as truncated_value", systemNamespace))
            .filter("truncated_value == 0")
            .count();
    Assert.assertEquals(
        "A truncate function with variable widths should be usable on dataframe columns",
        numRows,
        numZeroRows);
  }

  @Test
  public void testThatMagicFunctionsAreInvoked() {
    // Magic functions have staticinvoke in the explain output. Nonmagic calls have
    // applyfunctionexpression instead.
    // TinyInt
    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select %s.truncate(5, 6Y)", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateTinyInt");

    // SmallInt
    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select %s.truncate(5, 6S)", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateSmallInt");

    // Int
    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED select %s.truncate(5, 6)", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt");

    // Long
    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT %s.truncate(5, 6L)", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt");

    // String
    Assertions.assertThat(
            scalarSql("EXPLAIN EXTENDED SELECT %s.truncate(5, 'abcdefg')", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateString");

    // Decimal
    Assertions.assertThat(
            scalarSql("EXPLAIN EXTENDED SELECT %s.truncate(5, 12.34)", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateDecimal");

    // Binary
    Assertions.assertThat(
            scalarSql(
                "EXPLAIN EXTENDED SELECT %s.truncate(4, X'0102030405060708')", systemNamespace))
        .asString()
        .isNotNull()
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBinary");
  }

  @Test
  public void testMagicFunctionsResolveForTinyIntAndSmallIntWidths() {
    // Magic functions have staticinvoke in the explain output. Nonmagic calls use
    // applyfunctionexpression instead. Narrow width types (tinyint/smallint) should be
    // implicitly cast to int so the magic variants still resolve.
    String tinyIntWidthExplain =
        (String) scalarSql("EXPLAIN EXTENDED SELECT %s.truncate(1Y, 6)", systemNamespace);
    Assertions.assertThat(tinyIntWidthExplain)
        .contains("cast(1 as int)")
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateInt");

    String smallIntWidth =
        (String) scalarSql("EXPLAIN EXTENDED SELECT %s.truncate(5S, 6L)", systemNamespace);
    Assertions.assertThat(smallIntWidth)
        .contains("cast(5 as int)")
        .contains(
            "staticinvoke(class org.apache.iceberg.spark.functions.TruncateFunction$TruncateBigInt");
  }

  /** Renders the UTF-8 bytes of {@code value} as a SQL hexadecimal binary literal (X'…'). */
  private String asBytesLiteral(String value) {
    byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
    return "X'" + BaseEncoding.base16().encode(bytes) + "'";
  }
}