1 change: 1 addition & 0 deletions build.gradle
@@ -510,6 +510,7 @@ project(':iceberg-spark') {
}
testCompile project(path: ':iceberg-hive', configuration: 'testArtifacts')
testCompile project(path: ':iceberg-api', configuration: 'testArtifacts')
testCompile project(path: ':iceberg-data', configuration: 'testArtifacts')
}

test {
@@ -103,7 +103,7 @@ public OrcValueWriter<?> primitive(Type.PrimitiveType iPrimitive, TypeDescription
return GenericOrcWriters.byteBuffers();
case DECIMAL:
Types.DecimalType decimalType = (Types.DecimalType) iPrimitive;
return GenericOrcWriters.decimal(decimalType.scale(), decimalType.precision());
return GenericOrcWriters.decimal(decimalType.precision(), decimalType.scale());
default:
throw new IllegalArgumentException(String.format("Invalid iceberg type %s corresponding to ORC type %s",
iPrimitive, primitive));
@@ -33,6 +33,7 @@
import java.util.Map;
import java.util.UUID;
import org.apache.iceberg.orc.OrcValueWriter;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.orc.storage.common.type.HiveDecimal;
import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -103,11 +104,13 @@ public static OrcValueWriter<LocalDateTime> timestamp() {
return TimestampWriter.INSTANCE;
}

public static OrcValueWriter<BigDecimal> decimal(int scale, int precision) {
public static OrcValueWriter<BigDecimal> decimal(int precision, int scale) {
if (precision <= 18) {
return new Decimal18Writer(scale);
return new Decimal18Writer(precision, scale);
} else if (precision <= 38) {
return new Decimal38Writer(precision, scale);
} else {
return Decimal38Writer.INSTANCE;
throw new IllegalArgumentException("Invalid precision: " + precision);
}
}
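
The 18-digit cutoff is the largest decimal precision whose unscaled value is guaranteed to fit in a signed 64-bit long (Long.MAX_VALUE has 19 digits), which is why Decimal18Writer can take the long-backed setFromLongAndScale path while Decimal38Writer goes through HiveDecimal. A minimal JDK-only sketch of the same dispatch rule, using illustrative names rather than the Iceberg classes:

import java.math.BigDecimal;

public class DecimalDispatchSketch {
  // 18 is the largest digit count that always fits in a long: Long.MAX_VALUE = 9223372036854775807.
  static final int MAX_LONG_PRECISION = 18;

  static String pickWriter(int precision) {
    if (precision <= MAX_LONG_PRECISION) {
      return "long-backed writer";        // analogous to Decimal18Writer
    } else if (precision <= 38) {
      return "HiveDecimal-backed writer"; // analogous to Decimal38Writer
    } else {
      throw new IllegalArgumentException("Invalid precision: " + precision);
    }
  }

  public static void main(String[] args) {
    System.out.println(pickWriter(10));  // long-backed writer
    System.out.println(pickWriter(38));  // HiveDecimal-backed writer
    // An 18-digit unscaled value still converts exactly:
    System.out.println(new BigDecimal("999999999999999.999").unscaledValue().longValueExact());
  }
}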

@@ -288,8 +291,10 @@ public Class<OffsetDateTime> getJavaClass() {
@Override
public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
cv.time[rowId] = data.toInstant().toEpochMilli(); // millis
cv.nanos[rowId] = (data.getNano() / 1_000) * 1_000; // truncate nanos to only keep microsecond precision
// millis
cv.time[rowId] = data.toInstant().toEpochMilli();
// truncate nanos to only keep microsecond precision
cv.nanos[rowId] = data.getNano() / 1_000 * 1_000;
}
}
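
For this generic writer the plain integer division is safe even for pre-epoch timestamps, because OffsetDateTime.getNano() is always in [0, 999_999_999] and Instant.toEpochMilli() already floors toward negative infinity; only the epoch-based microsecond value handled by the Spark TimestampTzWriter further down needs floorDiv/floorMod. A small JDK-only illustration:

import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;

public class GenericTimestampSketch {
  public static void main(String[] args) {
    // One microsecond before the epoch.
    OffsetDateTime odt = Instant.EPOCH.minusNanos(1_000).atOffset(ZoneOffset.UTC);

    // toEpochMilli() already floors toward negative infinity:
    System.out.println(odt.toInstant().toEpochMilli());   // -1

    // getNano() is always non-negative, so plain division is enough to
    // truncate to microsecond precision here:
    System.out.println(odt.getNano() / 1_000 * 1_000);    // 999999000
  }
}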

@@ -311,9 +316,11 @@ public void nonNullWrite(int rowId, LocalDateTime data, ColumnVector output) {
}

private static class Decimal18Writer implements OrcValueWriter<BigDecimal> {
private final int precision;
private final int scale;

Decimal18Writer(int scale) {
Decimal18Writer(int precision, int scale) {
this.precision = precision;
this.scale = scale;
}

@@ -324,14 +331,24 @@ public Class<BigDecimal> getJavaClass() {

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
// TODO: validate precision and scale from schema
Preconditions.checkArgument(data.scale() == scale,
"Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
Preconditions.checkArgument(data.precision() <= precision,
"Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);

((DecimalColumnVector) output).vector[rowId]
.setFromLongAndScale(data.unscaledValue().longValueExact(), scale);
}
}
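
The two Preconditions replace the removed TODO: the incoming BigDecimal must already carry the column's scale, and its precision must not exceed the declared precision, otherwise longValueExact()/setFromLongAndScale would quietly store a wrong value. A self-contained sketch of the same validation; checkDecimal is a hypothetical helper, not an Iceberg API:

import java.math.BigDecimal;

public class DecimalCheckSketch {
  // Hypothetical helper mirroring the Preconditions checks above.
  static void checkDecimal(BigDecimal data, int precision, int scale) {
    if (data.scale() != scale) {
      throw new IllegalArgumentException(String.format(
          "Cannot write value as decimal(%d,%d), wrong scale: %s", precision, scale, data));
    }
    if (data.precision() > precision) {
      throw new IllegalArgumentException(String.format(
          "Cannot write value as decimal(%d,%d), invalid precision: %s", precision, scale, data));
    }
  }

  public static void main(String[] args) {
    checkDecimal(new BigDecimal("10.100"), 10, 3);   // ok: scale 3, precision 5
    try {
      checkDecimal(new BigDecimal("10.1"), 10, 3);   // rejected: scale 1 != 3
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage());
    }
  }
}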

private static class Decimal38Writer implements OrcValueWriter<BigDecimal> {
private static final OrcValueWriter<BigDecimal> INSTANCE = new Decimal38Writer();
private final int precision;
private final int scale;

Decimal38Writer(int precision, int scale) {
this.precision = precision;
this.scale = scale;
}

@Override
public Class<BigDecimal> getJavaClass() {
@@ -340,7 +357,11 @@ public Class<BigDecimal> getJavaClass() {

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
// TODO: validate precision and scale from schema
Preconditions.checkArgument(data.scale() == scale,
"Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, data);
Preconditions.checkArgument(data.precision() <= precision,
"Cannot write value as decimal(%s,%s), invalid precision: %s", precision, scale, data);

((DecimalColumnVector) output).vector[rowId].set(HiveDecimal.create(data, false));
}
}
@@ -24,6 +24,7 @@
import java.util.Map;
import org.apache.iceberg.orc.OrcValueReader;
import org.apache.iceberg.orc.OrcValueReaders;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -57,8 +58,10 @@ public static OrcValueReader<Long> timestampTzs() {
public static OrcValueReader<Decimal> decimals(int precision, int scale) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return new SparkOrcValueReaders.Decimal18Reader(precision, scale);
} else {
} else if (precision <= 38) {
return new SparkOrcValueReaders.Decimal38Reader(precision, scale);
} else {
throw new IllegalArgumentException("Invalid precision: " + precision);
}
}
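
Decimal.MAX_LONG_DIGITS() is Spark's limit (18) for decimals kept in its compact unscaled-long representation, so this reader split mirrors the writer split in GenericOrcWriters.decimal(). A short sketch of how such a Spark Decimal is assembled from an unscaled long, assuming Spark is on the classpath; the class name is illustrative:

import org.apache.spark.sql.types.Decimal;

public class SparkDecimalSketch {
  public static void main(String[] args) {
    // Expected to print 18: the compact-long limit that the 18/38 split keys off.
    System.out.println(Decimal.MAX_LONG_DIGITS());

    // Building a Decimal from an unscaled long, the same way Decimal18Reader does
    // with value.serialize64(scale): unscaled 10100 at scale 3 is 10.100.
    Decimal d = new Decimal().set(10100L, 10, 3);
    System.out.println(d);  // 10.100
  }
}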

@@ -177,13 +180,12 @@ private TimestampTzReader() {

@Override
public Long nonNullRead(ColumnVector vector, int row) {
TimestampColumnVector timestampVector = (TimestampColumnVector) vector;
return (timestampVector.time[row] / 1000) * 1_000_000 + timestampVector.nanos[row] / 1000;
TimestampColumnVector tcv = (TimestampColumnVector) vector;
return (Math.floorDiv(tcv.time[row], 1_000)) * 1_000_000 + Math.floorDiv(tcv.nanos[row], 1000);
}
}
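
Math.floorDiv matters for timestamps before the epoch: Java's / truncates toward zero, so the old expression put negative millisecond values on the wrong side of the epoch. A JDK-only worked example for an instant one microsecond before the epoch, with the column populated as the writers above and below do (time[row] = -1, nanos[row] = 999_999_000):

public class TimestampReadSketch {
  public static void main(String[] args) {
    long timeMillis = -1L;           // TimestampColumnVector.time: milliseconds, floored
    int nanosOfSecond = 999_999_000; // TimestampColumnVector.nanos for the same instant

    // Old arithmetic: (-1 / 1000) truncates to 0, yielding +999_999 micros, a full second off.
    System.out.println((timeMillis / 1000) * 1_000_000 + nanosOfSecond / 1000);

    // New arithmetic: floorDiv keeps the value before the epoch, yielding -1 micros.
    System.out.println(Math.floorDiv(timeMillis, 1_000) * 1_000_000 + Math.floorDiv(nanosOfSecond, 1000));
  }
}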

private static class Decimal18Reader implements OrcValueReader<Decimal> {
//TODO: these are being unused. check for bug
private final int precision;
private final int scale;

@@ -195,7 +197,15 @@ private static class Decimal18Reader implements OrcValueReader<Decimal> {
@Override
public Decimal nonNullRead(ColumnVector vector, int row) {
HiveDecimalWritable value = ((DecimalColumnVector) vector).vector[row];
return new Decimal().set(value.serialize64(value.scale()), value.precision(), value.scale());

// The scale of a decimal read from a Hive ORC file may not equal the expected scale. For data type
// decimal(10,3) and the value 10.100, the Hive ORC writer removes the trailing zeros and stores it
// as 101*10^(-1), so its scale is adjusted from 3 to 1. We therefore cannot assert that
// value.scale() == scale; we also need to convert the Hive ORC decimal to a decimal with the expected
// precision and scale.
Preconditions.checkArgument(value.precision() <= precision,
"Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);
Contributor: I'm not sure we need to check the precision either. If we read a value, then we should return it, right?

Member (author): It is necessary to do this check. We need to make sure there was no bug when the decimal was written into ORC. For example, if for a decimal(3, 0) data type we encounter a Hive decimal of 10000 (whose precision is 5), something must have gone wrong, and throwing an exception is the correct way to handle that case.

return new Decimal().set(value.serialize64(scale), precision, scale);
}
}
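
The comment's example is easy to reproduce with the JDK alone: HiveDecimal normalizes away trailing zeros much like BigDecimal.stripTrailingZeros(), and value.serialize64(scale) is then used to get the unscaled value back at the declared scale before it goes into Spark's Decimal. A JDK-only sketch of that rescaling:

import java.math.BigDecimal;

public class TrailingZeroSketch {
  public static void main(String[] args) {
    // What the writer handed to ORC for a decimal(10,3) column:
    BigDecimal written = new BigDecimal("10.100");       // unscaled 10100, scale 3

    // HiveDecimal drops trailing zeros, roughly like stripTrailingZeros():
    BigDecimal stored = written.stripTrailingZeros();    // unscaled 101, scale 1
    System.out.println(stored.unscaledValue() + " scale=" + stored.scale());

    // Rescaling back to the declared scale recovers the unscaled long the reader needs:
    long unscaledAtDeclaredScale = stored.setScale(3).unscaledValue().longValueExact();
    System.out.println(unscaledAtDeclaredScale);          // 10100
  }
}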

@@ -212,6 +222,10 @@ private static class Decimal38Reader implements OrcValueReader<Decimal> {
public Decimal nonNullRead(ColumnVector vector, int row) {
BigDecimal value = ((DecimalColumnVector) vector).vector[row]
.getHiveDecimal().bigDecimalValue();

Preconditions.checkArgument(value.precision() <= precision,
"Cannot read value as decimal(%s,%s), too large: %s", precision, scale, value);

return new Decimal().set(new scala.math.BigDecimal(value), precision, scale);
}
}
@@ -183,9 +183,9 @@ private static class TimestampTzWriter implements SparkOrcValueWriter {
@Override
public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
long micros = data.getLong(column);
cv.time[rowId] = micros / 1_000; // millis
cv.nanos[rowId] = (int) (micros % 1_000_000) * 1_000; // nanos
long micros = data.getLong(column); // it could be negative.
cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis
cv.nanos[rowId] = (int) (Math.floorMod(micros, 1_000_000)) * 1_000; // nanos
}
}
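
The write path has the same pre-epoch problem: with truncating / and %, micros = -1 would produce time = 0 and nanos = -1000, and a negative nanos value is outside the [0, 999_999_999] range the column vector expects. A JDK-only worked example of the old versus new arithmetic:

public class TimestampWriteSketch {
  public static void main(String[] args) {
    long micros = -1L;  // one microsecond before the epoch

    // Old arithmetic: truncating division and a negative remainder (invalid nanos).
    System.out.println(micros / 1_000);                                    // 0
    System.out.println((int) (micros % 1_000_000) * 1_000);                // -1000

    // New arithmetic: millis floor toward negative infinity, nanos stay non-negative.
    System.out.println(Math.floorDiv(micros, 1_000));                      // -1
    System.out.println((int) (Math.floorMod(micros, 1_000_000)) * 1_000);  // 999999000
  }
}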

@@ -237,7 +237,7 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual
break;
case DATE:
Assert.assertTrue("Should expect a LocalDate", expected instanceof LocalDate);
long expectedDays = ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
int expectedDays = (int) ChronoUnit.DAYS.between(EPOCH_DAY, (LocalDate) expected);
Assert.assertEquals("Primitive value should be equal to expected", expectedDays, actual);
break;
case TIMESTAMP:
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iceberg.spark.data;

import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.Iterator;
import java.util.List;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.data.RandomGenericData;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.data.orc.GenericOrcReader;
import org.apache.iceberg.data.orc.GenericOrcWriter;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
import org.junit.Assert;
import org.junit.Test;

import static org.apache.iceberg.types.Types.NestedField.required;

public class TestSparkRecordOrcReaderWriter extends AvroDataTest {
private static final int NUM_RECORDS = 200;

private void writeAndValidate(Schema schema, List<Record> expectedRecords) throws IOException {
final File originalFile = temp.newFile();
Assert.assertTrue("Delete should succeed", originalFile.delete());

// Write a few generic records into the original test file.
try (FileAppender<Record> writer = ORC.write(Files.localOutput(originalFile))
.createWriterFunc(GenericOrcWriter::buildWriter)
.schema(schema)
.build()) {
writer.addAll(expectedRecords);
}

// Read into spark InternalRow from the original test file.
List<InternalRow> internalRows = Lists.newArrayList();
try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(originalFile))
.project(schema)
.createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
.build()) {
reader.forEach(internalRows::add);
assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
}

final File anotherFile = temp.newFile();
Assert.assertTrue("Delete should succeed", anotherFile.delete());

// Write those spark InternalRows into a new file again.
try (FileAppender<InternalRow> writer = ORC.write(Files.localOutput(anotherFile))
.createWriterFunc(SparkOrcWriter::new)
.schema(schema)
.build()) {
writer.addAll(internalRows);
}

// Check whether the InternalRows are expected records.
try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(anotherFile))
.project(schema)
.createReaderFunc(readOrcSchema -> new SparkOrcReader(schema, readOrcSchema))
.build()) {
assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size());
}

// Read into iceberg GenericRecord and check again.
try (CloseableIterable<Record> reader = ORC.read(Files.localInput(anotherFile))
.createReaderFunc(typeDesc -> GenericOrcReader.buildReader(schema, typeDesc))
.project(schema)
.build()) {
assertRecordEquals(expectedRecords, reader, expectedRecords.size());
}
}

@Override
protected void writeAndValidate(Schema schema) throws IOException {
List<Record> expectedRecords = RandomGenericData.generate(schema, NUM_RECORDS, 1992L);
writeAndValidate(schema, expectedRecords);
}

@Test
public void testDecimalWithTrailingZero() throws IOException {
Schema schema = new Schema(
required(1, "d1", Types.DecimalType.of(10, 2)),
required(2, "d2", Types.DecimalType.of(20, 5)),
required(3, "d3", Types.DecimalType.of(38, 20))
);

List<Record> expected = Lists.newArrayList();

GenericRecord record = GenericRecord.create(schema);
record.set(0, new BigDecimal("101.00"));
record.set(1, new BigDecimal("10.00E-3"));
record.set(2, new BigDecimal("1001.0000E-16"));

expected.add(record.copy());

writeAndValidate(schema, expected);
}
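
The three literals are chosen so that every value carries trailing zeros at its declared scale: 101.00 has scale 2, 10.00E-3 parses to 0.01000 (unscaled 1000, scale 5), and 1001.0000E-16 parses to scale 20. HiveDecimal drops those zeros on write, so the round trip only succeeds if the readers rescale back to the schema's scale. A JDK-only check of the parsed representations; the class name is illustrative:

import java.math.BigDecimal;

public class TrailingZeroValuesSketch {
  public static void main(String[] args) {
    for (String literal : new String[] {"101.00", "10.00E-3", "1001.0000E-16"}) {
      BigDecimal d = new BigDecimal(literal);
      System.out.println(literal + " -> unscaled=" + d.unscaledValue()
          + " scale=" + d.scale() + " precision=" + d.precision());
    }
    // 101.00        -> unscaled=10100     scale=2  precision=5
    // 10.00E-3      -> unscaled=1000      scale=5  precision=4
    // 1001.0000E-16 -> unscaled=10010000  scale=20 precision=8
  }
}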

private static void assertRecordEquals(Iterable<Record> expected, Iterable<Record> actual, int size) {
Iterator<Record> expectedIter = expected.iterator();
Iterator<Record> actualIter = actual.iterator();
for (int i = 0; i < size; i += 1) {
Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next());
}
Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
}

private static void assertEqualsUnsafe(Types.StructType struct, Iterable<Record> expected,
Iterable<InternalRow> actual, int size) {
Iterator<Record> expectedIter = expected.iterator();
Iterator<InternalRow> actualIter = actual.iterator();
for (int i = 0; i < size; i += 1) {
Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext());
Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext());
GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next());
}
Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext());
Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext());
}
}