Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8
Expand Down Expand Up @@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
private lazy val decimalConversions = new DecimalConversion()

private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
catalystType match {
case NullType =>
(catalystType, avroType.getType) match {
case (NullType, NULL) =>
(getter, ordinal) => null
case BooleanType =>
case (BooleanType, BOOLEAN) =>
(getter, ordinal) => getter.getBoolean(ordinal)
case ByteType =>
case (ByteType, INT) =>
(getter, ordinal) => getter.getByte(ordinal).toInt
case ShortType =>
case (ShortType, INT) =>
(getter, ordinal) => getter.getShort(ordinal).toInt
case IntegerType =>
case (IntegerType, INT) =>
(getter, ordinal) => getter.getInt(ordinal)
case LongType =>
case (LongType, LONG) =>
(getter, ordinal) => getter.getLong(ordinal)
case FloatType =>
case (FloatType, FLOAT) =>
(getter, ordinal) => getter.getFloat(ordinal)
case DoubleType =>
case (DoubleType, DOUBLE) =>
Copy link
Copy Markdown
Member

@dbtsai dbtsai Aug 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow users to do casting up from catalystType to avroType? For example, catalystType float to avroType double. If so, this can be done in different PR.

Copy link
Copy Markdown
Member Author

@gengliangwang gengliangwang Aug 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would like to keep it simple as this PR proposes.
If data type casting needed, users can always do it in DataFrame before writing Avro files.

But if the casting is important, we can work on it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if someone feels it's important, let's do it in different PR.

(getter, ordinal) => getter.getDouble(ordinal)
case d: DecimalType =>
case (d: DecimalType, FIXED)
if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
(getter, ordinal) =>
val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
LogicalTypes.decimal(d.precision, d.scale))

case StringType => avroType.getType match {
case Type.ENUM =>
import scala.collection.JavaConverters._
val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
(getter, ordinal) =>
val data = getter.getUTF8String(ordinal).toString
if (!enumSymbols.contains(data)) {
throw new IncompatibleSchemaException(
"Cannot write \"" + data + "\" since it's not defined in enum \"" +
enumSymbols.mkString("\", \"") + "\"")
}
new EnumSymbol(avroType, data)
case _ =>
(getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
}
case BinaryType => avroType.getType match {
case Type.FIXED =>
val size = avroType.getFixedSize()
(getter, ordinal) =>
val data: Array[Byte] = getter.getBinary(ordinal)
if (data.length != size) {
throw new IncompatibleSchemaException(
s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
"binary data into FIXED Type with size of " +
s"$size ${if (size > 1) "bytes" else "byte"}")
}
new Fixed(avroType, data)
case _ =>
(getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
}
case DateType =>
case (d: DecimalType, BYTES)
if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
(getter, ordinal) =>
val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
LogicalTypes.decimal(d.precision, d.scale))

case (StringType, ENUM) =>
val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
(getter, ordinal) =>
val data = getter.getUTF8String(ordinal).toString
if (!enumSymbols.contains(data)) {
throw new IncompatibleSchemaException(
"Cannot write \"" + data + "\" since it's not defined in enum \"" +
enumSymbols.mkString("\", \"") + "\"")
}
new EnumSymbol(avroType, data)

case (StringType, STRING) =>
(getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)

case (BinaryType, FIXED) =>
val size = avroType.getFixedSize()
(getter, ordinal) =>
val data: Array[Byte] = getter.getBinary(ordinal)
if (data.length != size) {
throw new IncompatibleSchemaException(
s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
"binary data into FIXED Type with size of " +
s"$size ${if (size > 1) "bytes" else "byte"}")
}
new Fixed(avroType, data)

case (BinaryType, BYTES) =>
(getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))

case (DateType, INT) =>
(getter, ordinal) => getter.getInt(ordinal)
case TimestampType => avroType.getLogicalType match {

case (TimestampType, LONG) => avroType.getLogicalType match {
case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
// For backward compatibility, if the Avro type is Long and it is not logical type,
Expand All @@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
}

case ArrayType(et, containsNull) =>
case (ArrayType(et, containsNull), ARRAY) =>
val elementConverter = newConverter(
et, resolveNullableType(avroType.getElementType, containsNull))
(getter, ordinal) => {
Expand All @@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
java.util.Arrays.asList(result: _*)
}

case st: StructType =>
case (st: StructType, RECORD) =>
val structConverter = newStructConverter(st, avroType)
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))

case MapType(kt, vt, valueContainsNull) if kt == StringType =>
case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull))
(getter, ordinal) =>
Expand All @@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
result

case other =>
throw new IncompatibleSchemaException(s"Unexpected type: $other")
throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " +
s"Avro type $avroType.")
}
}

private def newStructConverter(
catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
if (avroStruct.getType != RECORD) {
throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
s"Avro type $avroStruct.")
}
val avroFields = avroStruct.getFields
assert(avroFields.size() == catalystStruct.length)
val fieldConverters = catalystStruct.zip(avroFields.asScala).map {
Expand All @@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
}

private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
if (nullable) {
if (nullable && avroType.getType != NULL) {
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a trivial bug if avroType is NULL type.

// avro uses union to represent nullable type.
val fields = avroType.getTypes.asScala
assert(fields.length == 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,46 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU
}
}

test("Logical type: write Decimal with BYTES type") {
val specifiedSchema = """
{
"type" : "record",
"name" : "topLevelRecord",
"namespace" : "topLevelRecord",
"fields" : [ {
"name" : "bytes",
"type" : [ {
"type" : "bytes",
"namespace" : "topLevelRecord.bytes",
"logicalType" : "decimal",
"precision" : 4,
"scale" : 2
}, "null" ]
}, {
"name" : "fixed",
"type" : [ {
"type" : "bytes",
"logicalType" : "decimal",
"precision" : 4,
"scale" : 2
}, "null" ]
} ]
}
"""
withTempDir { dir =>
val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath)
assert(specifiedSchema != avroSchema)
val expected =
decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
val df = spark.read.format("avro").load(avroFile)

withTempPath { path =>
df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString)
checkAnswer(spark.read.format("avro").load(path.toString), expected)
}
}
}

test("Logical type: Decimal with too large precision") {
withTempDir { dir =>
val schema = new Schema.Parser().parse("""{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._

import org.apache.avro.Schema
import org.apache.avro.Schema.{Field, Type}
import org.apache.avro.Schema.Type._
import org.apache.avro.file.{DataFileReader, DataFileWriter}
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
Expand Down Expand Up @@ -850,6 +851,62 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}

test("throw exception if unable to write with user provided Avro schema") {
val input: Seq[(DataType, Schema.Type)] = Seq(
(NullType, NULL),
(BooleanType, BOOLEAN),
(ByteType, INT),
(ShortType, INT),
(IntegerType, INT),
(LongType, LONG),
(FloatType, FLOAT),
(DoubleType, DOUBLE),
(BinaryType, BYTES),
(DateType, INT),
(TimestampType, LONG),
(DecimalType(4, 2), BYTES)
)
def assertException(f: () => AvroSerializer) {
val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] {
f()
}.getMessage
assert(message.contains("Cannot convert Catalyst type"))
}

def resolveNullable(schema: Schema, nullable: Boolean): Schema = {
if (nullable && schema.getType != NULL) {
Schema.createUnion(schema, Schema.create(NULL))
} else {
schema
}
}
for {
i <- input
j <- input
nullable <- Seq(true, false)
} if (i._2 != j._2) {
val avroType = resolveNullable(Schema.create(j._2), nullable)
val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable)
val avroMapType = resolveNullable(Schema.createMap(avroType), nullable)
val name = "foo"
val avroField = new Field(name, avroType, "", null)
val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava)
val avroRecordType = resolveNullable(recordSchema, nullable)

val catalystType = i._1
val catalystArrayType = ArrayType(catalystType, nullable)
val catalystMapType = MapType(StringType, catalystType, nullable)
val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable)))

for {
avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType)
catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType)
} {
assertException(() => new AvroSerializer(catalyst, avro, nullable))
}
}
}

test("reading from invalid path throws exception") {

// Directory given has no avro files
Expand Down