diff --git a/build.gradle b/build.gradle
index c20c528cf2c6..90074265eee9 100644
--- a/build.gradle
+++ b/build.gradle
@@ -510,6 +510,7 @@ project(':iceberg-spark') {
     }
     testCompile project(path: ':iceberg-hive', configuration: 'testArtifacts')
     testCompile project(path: ':iceberg-api', configuration: 'testArtifacts')
+    testCompile project(path: ':iceberg-data', configuration: 'testArtifacts')
   }
 
   test {
diff --git a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java
index 1ed013abe894..b55cb8c14f92 100644
--- a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java
+++ b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java
@@ -103,7 +103,7 @@ public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescriptio
         return GenericOrcWriters.byteBuffers();
       case DECIMAL:
         Types.DecimalType decimalType = (Types.DecimalType) iPrimitive;
-        return GenericOrcWriters.decimal(decimalType.scale(), decimalType.precision());
+        return GenericOrcWriters.decimal(decimalType.precision(), decimalType.scale());
       default:
         throw new IllegalArgumentException(String.format("Invalid iceberg type %s corresponding to ORC type %s",
             iPrimitive, primitive));
diff --git a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java
index 6103c1e3e8b7..12d70f5225ad 100644
--- a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java
+++ b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java
@@ -33,6 +33,7 @@
 import java.util.Map;
 import java.util.UUID;
 import org.apache.iceberg.orc.OrcValueWriter;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.orc.storage.common.type.HiveDecimal;
 import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -103,11 +104,13 @@ public static OrcValueWriter<LocalDateTime> timestamp() {
     return TimestampWriter.INSTANCE;
   }
 
-  public static OrcValueWriter<BigDecimal> decimal(int scale, int precision) {
+  public static OrcValueWriter<BigDecimal> decimal(int precision, int scale) {
     if (precision <= 18) {
-      return new Decimal18Writer(scale);
+      return new Decimal18Writer(precision, scale);
+    } else if (precision <= 38) {
+      return new Decimal38Writer(precision, scale);
     } else {
-      return Decimal38Writer.INSTANCE;
+      throw new IllegalArgumentException("Invalid precision: " + precision);
     }
   }
 
@@ -288,8 +291,10 @@ public Class<?> getJavaClass() {
     @Override
     public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) {
       TimestampColumnVector cv = (TimestampColumnVector) output;
-      cv.time[rowId] = data.toInstant().toEpochMilli(); // millis
-      cv.nanos[rowId] = (data.getNano() / 1_000) * 1_000; // truncate nanos to only keep microsecond precision
+      // millis
+      cv.time[rowId] = data.toInstant().toEpochMilli();
+      // truncate nanos to only keep microsecond precision
+      cv.nanos[rowId] = data.getNano() / 1_000 * 1_000;
     }
   }
 
@@ -311,9 +316,11 @@ public void nonNullWrite(int rowId, LocalDateTime data, ColumnVector output) {
   }
 
   private static class Decimal18Writer implements OrcValueWriter<BigDecimal> {
+    private final int precision;
     private final int scale;
 
-    Decimal18Writer(int scale) {
+    Decimal18Writer(int precision, int scale) {
+      this.precision = precision;
       this.scale = scale;
     }
 
@@ -324,14 +331,24 @@ public Class<?> getJavaClass() {
 
     @Override
     public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
-      // TODO: validate precision and scale from schema
+      Preconditions.checkArgument(data.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
+      Preconditions.checkArgument(data.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);
+
       ((DecimalColumnVector) output).vector[rowId]
           .setFromLongAndScale(data.unscaledValue().longValueExact(), scale);
     }
   }
 
   private static class Decimal38Writer implements OrcValueWriter<BigDecimal> {
-    private static final OrcValueWriter<BigDecimal> INSTANCE = new Decimal38Writer();
+    private final int precision;
+    private final int scale;
+
+    Decimal38Writer(int precision, int scale) {
+      this.precision = precision;
+      this.scale = scale;
+    }
 
     @Override
     public Class<?> getJavaClass() {
@@ -340,7 +357,11 @@ public Class<?> getJavaClass() {
 
     @Override
     public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
-      // TODO: validate precision and scale from schema
+      Preconditions.checkArgument(data.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
+      Preconditions.checkArgument(data.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);
+
       ((DecimalColumnVector) output).vector[rowId].set(HiveDecimal.create(data, false));
     }
   }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
index e7dfe7f43e11..70e8b08fc4a0 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
@@ -24,6 +24,7 @@
 import java.util.Map;
 import org.apache.iceberg.orc.OrcValueReader;
 import org.apache.iceberg.orc.OrcValueReaders;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types;
 import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -57,8 +58,10 @@ public static OrcValueReader<Long> timestampTzs() {
   public static OrcValueReader<Decimal> decimals(int precision, int scale) {
     if (precision <= Decimal.MAX_LONG_DIGITS()) {
       return new SparkOrcValueReaders.Decimal18Reader(precision, scale);
-    } else {
+    } else if (precision <= 38) {
       return new SparkOrcValueReaders.Decimal38Reader(precision, scale);
+    } else {
+      throw new IllegalArgumentException("Invalid precision: " + precision);
     }
   }
 
@@ -177,13 +180,12 @@ private TimestampTzReader() {
 
     @Override
     public Long nonNullRead(ColumnVector vector, int row) {
-      TimestampColumnVector timestampVector = (TimestampColumnVector) vector;
-      return (timestampVector.time[row] / 1000) * 1_000_000 + timestampVector.nanos[row] / 1000;
+      TimestampColumnVector tcv = (TimestampColumnVector) vector;
+      return (Math.floorDiv(tcv.time[row], 1_000)) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000);
     }
   }
 
   private static class Decimal18Reader implements OrcValueReader<Decimal> {
-    //TODO: these are being unused. check for bug
     private final int precision;
     private final int scale;
 
@@ -195,7 +197,15 @@ private static class Decimal18Reader implements OrcValueReader<Decimal> {
     @Override
     public Decimal nonNullRead(ColumnVector vector, int row) {
       HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row];
-      return new Decimal().set(value.serialize64(value.scale()), value.precision(), value.scale());
+
+      // The scale of a decimal read from a Hive ORC file may not equal the expected scale. For the data type
+      // decimal(10,3) and the value 10.100, the Hive ORC writer strips the trailing zeros and stores it
+      // as 101*10^(-1), so its scale is adjusted from 3 to 1. We therefore cannot assert that value.scale() == scale;
+      // we also need to convert the Hive ORC decimal to a decimal with the expected precision and scale.
+      Preconditions.checkArgument(value.precision() <= precision,
+          "Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);
+
+      return new Decimal().set(value.serialize64(scale), precision, scale);
     }
   }
 
@@ -212,6 +222,10 @@ private static class Decimal38Reader implements OrcValueReader<Decimal> {
     public Decimal nonNullRead(ColumnVector vector, int row) {
       BigDecimal value = ((DecimalColumnVector) vector).vector[row]
           .getHiveDecimal().bigDecimalValue();
+
+      Preconditions.checkArgument(value.precision() <= precision,
+          "Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);
+
       return new Decimal().set(new scala.math.BigDecimal(value), precision, scale);
     }
   }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
index 8bb0f53f83cb..9148b5a8a89f 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
@@ -183,9 +183,9 @@ private static class TimestampTzWriter implements SparkOrcValueWriter {
     @Override
     public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
       TimestampColumnVector cv = (TimestampColumnVector) output;
-      long micros = data.getLong(column);
-      cv.time[rowId] = micros / 1_000; // millis
-      cv.nanos[rowId] = (int) (micros % 1_000_000) * 1_000; // nanos
+      long micros = data.getLong(column); // micros may be negative for timestamps before the epoch
+      cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis
+      cv.nanos[rowId] = (int) (Math.floorMod(micros, 1_000_000)) * 1_000; // nanos
     }
   }
 
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java b/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java
index 821f5bd66f9f..0c4598a209e8 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/GenericsHelpers.java
@@ -237,7 +237,7 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual
         break;
       case DATE:
         Assert.assertTrue("Should expect a LocalDate", expected instanceof LocalDate);
-        long expectedDays = ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
+        int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
         Assert.assertEquals("Primitive value should be equal to expected", expectedDays, actual);
         break;
       case TIMESTAMP:
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java
new file mode 100644
index 000000000000..1e7430d16df7
--- /dev/null
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.RandomGenericData;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.data.orc.GenericOrcReader;
+import org.apache.iceberg.data.orc.GenericOrcWriter;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.orc.ORC;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+
+public class TestSparkRecordOrcReaderWriter extends AvroDataTest {
+  private static final int NUM_RECORDS = 200;
+
+  private void writeAndValidate(Schema schema, List<Record> expectedRecords) throws IOException {
+    final File originalFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", originalFile.delete());
+
+    // Write a few generic records into the original test file.
+    try (FileAppender<Record> writer = ORC.write(Files.localOutput(originalFile))
+        .createWriterFunc(GenericOrcWriter::buildWriter)
+        .schema(schema)
+        .build()) {
+      writer.addAll(expectedRecords);
+    }
+
+    // Read into spark InternalRow from the original test file.
+    List<InternalRow> internalRows = Lists.newArrayList();
+    try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(originalFile))
+        .project(schema)
+        .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
+        .build()) {
+      reader.forEach(internalRows::add);
+      assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
+    }
+
+    final File anotherFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", anotherFile.delete());
+
+    // Write those spark InternalRows into a new file again.
+    try (FileAppender<InternalRow> writer = ORC.write(Files.localOutput(anotherFile))
+        .createWriterFunc(SparkOrcWriter::new)
+        .schema(schema)
+        .build()) {
+      writer.addAll(internalRows);
+    }
+
+    // Check whether the InternalRows are expected records.
+    try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(anotherFile))
+        .project(schema)
+        .createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
+        .build()) {
+      assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
+    }
+
+    // Read into iceberg GenericRecord and check again.
+    try (CloseableIterable<Record> reader = ORC.read(Files.localInput(anotherFile))
+        .createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc))
+        .project(schema)
+        .build()) {
+      assertRecordEquals(expectedRecords, reader, expectedRecords.size());
+    }
+  }
+
+  @Override
+  protected void writeAndValidate(Schema schema) throws IOException {
+    List<Record> expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L);
+    writeAndValidate(schema, expectedRecords);
+  }
+
+  @Test
+  public void testDecimalWithTrailingZero() throws IOException {
+    Schema schema = new Schema(
+        required(1, "d1", Types.DecimalType.of(10, 2)),
+        required(2, "d2", Types.DecimalType.of(20, 5)),
+        required(3, "d3", Types.DecimalType.of(38, 20))
+    );
+
+    List<Record> expected = Lists.newArrayList();
+
+    GenericRecord record = GenericRecord.create(schema);
+    record.set(0, new BigDecimal("101.00"));
+    record.set(1, new BigDecimal("10.00E-3"));
+    record.set(2, new BigDecimal("1001.0000E-16"));
+
+    expected.add(record.copy());
+
+    writeAndValidate(schema, expected);
+  }
+
+  private static void assertRecordEquals(Iterable<Record> expected, Iterable<Record> actual, int size) {
+    Iterator<Record> expectedIter = expected.iterator();
+    Iterator<Record> actualIter = actual.iterator();
+    for (int i = 0; i < size; i += 1) {
+      Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
+      Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
+      Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next());
+    }
+    Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
+    Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
+  }
+
+  private static void assertEqualsUnsafe(Types.StructType struct, Iterable<Record> expected,
+                                         Iterable<InternalRow> actual, int size) {
+    Iterator<Record> expectedIter = expected.iterator();
+    Iterator<InternalRow> actualIter = actual.iterator();
+    for (int i = 0; i < size; i += 1) {
+      Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
+      Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
+      GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next());
+    }
+    Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
+    Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
+  }
+}
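
Note (not part of the patch): two standalone sketches of the behavior the changes above work around. The class name OrcRoundTripNotes and the sample values are made up for illustration; HiveDecimal here is the same relocated class that GenericOrcWriters imports, and the expected outputs follow the comments added in Decimal18Reader and TimestampTzWriter.

import java.math.BigDecimal;
import org.apache.orc.storage.common.type.HiveDecimal;

public class OrcRoundTripNotes {
  public static void main(String[] args) {
    // 1. HiveDecimal trims trailing zeros: a decimal(10,3) value such as 10.100 is stored
    //    as 101 * 10^-1, i.e. with scale 1, so the reader cannot assert value.scale() == scale.
    BigDecimal source = new BigDecimal("10.100");
    HiveDecimal stored = HiveDecimal.create(source);
    System.out.println(stored + " scale=" + stored.scale());   // expected: 10.1 scale=1

    //    Re-applying the declared scale recovers the original value, which is what
    //    Decimal18Reader now does by serializing with the expected precision and scale.
    BigDecimal restored = stored.bigDecimalValue().setScale(3);
    System.out.println(restored.equals(source));                // expected: true

    // 2. Microsecond timestamps before the epoch are negative, e.g. -1500 us is
    //    1969-12-31T23:59:59.998500Z. Truncating division splits this into -1 ms and
    //    -1_500_000 ns, an invalid negative fraction of a second; floor division keeps
    //    the fraction non-negative and consistent with the millisecond part.
    long micros = -1_500L;
    System.out.println(micros / 1_000 + " ms, "
        + (int) (micros % 1_000_000) * 1_000 + " ns");          // -1 ms, -1500000 ns
    System.out.println(Math.floorDiv(micros, 1_000) + " ms, "
        + (int) Math.floorMod(micros, 1_000_000) * 1_000 + " ns");  // -2 ms, 998500000 ns
  }
}

The testDecimalWithTrailingZero case added above exercises exactly this round trip, writing decimals with trailing zeros through the generic writer, the Spark reader/writer, and the generic reader again.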