Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8
Expand Down Expand Up @@ -86,8 +87,18 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
case (LONG, LongType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])

case (LONG, TimestampType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case (LONG, TimestampType) => avroType.getLogicalType match {
case _: TimestampMillis => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case _: TimestampMicros => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case null => (updater, ordinal, value) =>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, add a default case.

// For backward compatibility, if the Avro type is Long and it is not logical type,
// the value is processed as timestamp type with millisecond precision.
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case other => throw new IncompatibleSchemaException(
s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
}

case (LONG, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ private[avro] class AvroFileFormat extends FileFormat
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
val outputAvroSchema = SchemaConverters.toAvroType(
dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace)
val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false,
parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType)

AvroJob.setOutputKeySchema(job, outputAvroSchema)
val COMPRESS_KEY = "mapred.output.compress"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType

/**
* Options for Avro Reader and Writer stored in case insensitive manner.
Expand Down Expand Up @@ -79,4 +80,14 @@ class AvroOptions(
val compression: String = {
parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
}

/**
* Avro timestamp type used when Spark writes data to Avro files.
* Currently supported types are `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS`.
* TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of microseconds
* from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with millisecond precision,
* which means Spark has to truncate the microsecond portion of its timestamp value.
* The related configuration is set via SQLConf, and it is not exposed as an option.
*/
val outputTimestampType: AvroOutputTimestampType.Value = SQLConf.get.avroOutputTimestampType
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.JavaConverters._

import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type.NULL
import org.apache.avro.generic.GenericData.Record
Expand Down Expand Up @@ -92,8 +93,15 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
(getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
case DateType =>
(getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY
case TimestampType =>
(getter, ordinal) => getter.getLong(ordinal) / 1000
case TimestampType => avroType.getLogicalType match {
case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
// For backward compatibility, if the Avro type is Long and it is not logical type,
// output the timestamp value as with millisecond precision.
case null => (getter, ordinal) => getter.getLong(ordinal) / 1000
case other => throw new IncompatibleSchemaException(
s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
}

case ArrayType(et, containsNull) =>
val elementConverter = newConverter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.spark.sql.avro

import scala.collection.JavaConverters._

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType
import org.apache.spark.sql.types._

/**
Expand All @@ -42,7 +44,10 @@ object SchemaConverters {
case BYTES => SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => SchemaType(LongType, nullable = false)
case LONG => avroSchema.getLogicalType match {
case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
case _ => SchemaType(LongType, nullable = false)
}
case FIXED => SchemaType(BinaryType, nullable = false)
case ENUM => SchemaType(StringType, nullable = false)

Expand Down Expand Up @@ -103,31 +108,45 @@ object SchemaConverters {
catalystType: DataType,
nullable: Boolean = false,
recordName: String = "topLevelRecord",
prevNameSpace: String = ""): Schema = {
prevNameSpace: String = "",
outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS)
: Schema = {
val builder = if (nullable) {
SchemaBuilder.builder().nullable()
} else {
SchemaBuilder.builder()
}

catalystType match {
case BooleanType => builder.booleanType()
case ByteType | ShortType | IntegerType => builder.intType()
case LongType => builder.longType()
case DateType => builder.longType()
case TimestampType => builder.longType()
case TimestampType =>
val timestampType = outputTimestampType match {
case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis()
case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros()
case other =>
throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
}
builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong()

case FloatType => builder.floatType()
case DoubleType => builder.doubleType()
case _: DecimalType | StringType => builder.stringType()
case BinaryType => builder.bytesType()
case ArrayType(et, containsNull) =>
builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace))
builder.array()
.items(toAvroType(et, containsNull, recordName, prevNameSpace, outputTimestampType))
case MapType(StringType, vt, valueContainsNull) =>
builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace))
builder.map()
.values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType))
case st: StructType =>
val nameSpace = s"$prevNameSpace.$recordName"
val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
st.foreach { f =>
val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace)
val fieldAvroType =
toAvroType(f.dataType, f.nullable, f.name, nameSpace, outputTimestampType)
fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
}
fieldsAssembler.endRecord()
Expand Down
Binary file added external/avro/src/test/resources/timestamp.avro
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,34 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._

class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
import testImplicits._

val episodesAvro = testFile("episodes.avro")
val testAvro = testFile("test.avro")

// The test file timestamp.avro is generated via following Python code:
// import json
// import avro.schema
// from avro.datafile import DataFileWriter
// from avro.io import DatumWriter
//
// write_schema = avro.schema.parse(json.dumps({
// "namespace": "logical",
// "type": "record",
// "name": "test",
// "fields": [
// {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}},
// {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}},
// {"name": "long", "type": "long"}
// ]
// }))
//
// writer = DataFileWriter(open("timestamp.avro", "wb"), DatumWriter(), write_schema)
// writer.append({"timestamp_millis": 1000, "timestamp_micros": 2000000, "long": 3000})
// writer.append({"timestamp_millis": 666000, "timestamp_micros": 999000000, "long": 777000})
// writer.close()
val timestampAvro = testFile("timestamp.avro")

override protected def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
Expand Down Expand Up @@ -331,6 +356,77 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}

test("Logical type: timestamp_millis") {
val expected = Seq(1000L, 666000L).map(t => Row(new Timestamp(t)))
val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis)

checkAnswer(df, expected)

withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}

test("Logical type: timestamp_micros") {
val expected = Seq(2000L, 999000L).map(t => Row(new Timestamp(t)))
val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros)

checkAnswer(df, expected)

withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}

test("Logical type: specify different output timestamp types") {
val df =
spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros)

val expected = Seq((1000L, 2000L), (666000L, 999000L))
.map(t => Row(new Timestamp(t._1), new Timestamp(t._2)))

Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType =>
withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) {
withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}
}
}

test("Read Long type as Timestamp") {
val schema = StructType(StructField("long", TimestampType, true) :: Nil)
val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long)

val expected = Seq(3000L, 777000L).map(t => Row(new Timestamp(t)))

checkAnswer(df, expected)
}

test("Logical type: user specified schema") {
val expected = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L))
.map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3))

val avroSchema = s"""
{
"namespace": "logical",
"type": "record",
"name": "test",
"fields": [
{"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}},
{"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}},
{"name": "long", "type": "long"}
]
}
"""
val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro)

checkAnswer(df, expected)
}

test("Array data types") {
withTempPath { dir =>
val testSchema = StructType(Seq(
Expand Down Expand Up @@ -511,7 +607,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

// TimesStamps are converted to longs
val times = spark.read.format("avro").load(avroDir).select("Time").collect()
assert(times.map(_(0)).toSet == Set(666, 777, 42))
assert(times.map(_(0)).toSet ==
Set(new Timestamp(666), new Timestamp(777), new Timestamp(42)))

// DecimalType should be converted to string
val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
Expand All @@ -530,9 +627,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

test("correctly read long as date/timestamp type") {
withTempPath { tempDir =>
val sparkSession = spark
import sparkSession.implicits._

val currentTime = new Timestamp(System.currentTimeMillis())
val currentDate = new Date(System.currentTimeMillis())
val schema = StructType(Seq(
Expand Down Expand Up @@ -560,9 +654,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

test("does not coerce null date/timestamp value to 0 epoch.") {
withTempPath { tempDir =>
val sparkSession = spark
import sparkSession.implicits._

val nullTime: Timestamp = null
val nullDate: Date = null
val schema = StructType(Seq(
Expand Down Expand Up @@ -768,8 +859,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

test("read avro file partitioned") {
withTempPath { dir =>
val sparkSession = spark
import sparkSession.implicits._
val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
val outputDir = s"$dir/${UUID.randomUUID}"
df.write.format("avro").save(outputDir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,21 @@ object SQLConf {
.intConf
.createWithDefault(20)

object AvroOutputTimestampType extends Enumeration {
val TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value
}

val AVRO_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.avro.outputTimestampType")
.doc("Sets which Avro timestamp type to use when Spark writes data to Avro files. " +
"TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of " +
"microseconds from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with " +
"millisecond precision, which means Spark has to truncate the microsecond portion of its " +
"timestamp value.")
.stringConf
.transform(_.toUpperCase(Locale.ROOT))
.checkValues(AvroOutputTimestampType.values.map(_.toString))
.createWithDefault(AvroOutputTimestampType.TIMESTAMP_MICROS.toString)

val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
.doc("Compression codec used in writing of AVRO files. Default codec is snappy.")
.stringConf
Expand Down Expand Up @@ -1835,6 +1850,9 @@ class SQLConf extends Serializable with Logging {

def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE)

def avroOutputTimestampType: AvroOutputTimestampType.Value =
AvroOutputTimestampType.withName(getConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE))

def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC)

def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL)
Expand Down