Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8
Expand Down Expand Up @@ -86,8 +87,14 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
case (LONG, LongType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])

case (LONG, TimestampType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case (LONG, TimestampType) => avroType.getLogicalType match {
case _: TimestampMillis => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case _: TimestampMicros => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case _ => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment to say it's for backward compatibility reasons. Also we should only do it when logical type is null. For other logical types, we should fail here.

}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a default case that throws IncompatibleSchemaException, in case Avro adds more logical types for the long type in the future.


case (LONG, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ private[avro] class AvroFileFormat extends FileFormat
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
val outputAvroSchema = SchemaConverters.toAvroType(
dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace)
val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false,
parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType)

AvroJob.setOutputKeySchema(job, outputAvroSchema)
val COMPRESS_KEY = "mapred.output.compress"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,15 @@ class AvroOptions(
// Compression codec to use: the `compression` entry of `parameters` when present,
// otherwise the `avroCompressionCodec` value of the active SQLConf.
val compression: String =
  parameters.getOrElse("compression", SQLConf.get.avroCompressionCodec)

/**
 * The `outputTimestampType` option sets which Avro timestamp type to use when Spark writes
 * data to Avro files. Currently supported types are `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS`.
 * TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of microseconds
 * from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with millisecond precision,
 * which means Spark has to truncate the microsecond portion of its timestamp value.
 *
 * A value given through the option is upper-cased (locale-independently) so that it is
 * matched case-insensitively, consistent with the `spark.sql.avro.outputTimestampType`
 * SQL config, which is used as the fallback when the option is absent.
 */
val outputTimestampType: String = {
  parameters
    .get("outputTimestampType")
    // Normalize user input; the SQL conf fallback is already upper-cased when it is set.
    .map(_.toUpperCase(java.util.Locale.ROOT))
    .getOrElse(SQLConf.get.avroOutputTimestampType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.JavaConverters._

import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type.NULL
import org.apache.avro.generic.GenericData.Record
Expand Down Expand Up @@ -93,7 +94,11 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
case DateType =>
(getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY
case TimestampType =>
(getter, ordinal) => getter.getLong(ordinal) / 1000
(getter, ordinal) => avroType.getLogicalType match {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not pattern match per record; we should match on the logical type once and return the converter function for it:

avroType.getLogicalType match {
  case _: TimestampMillis => (getter, ordinal) => ...

case _: TimestampMillis => getter.getLong(ordinal) / 1000
case _: TimestampMicros => getter.getLong(ordinal)
case _ => getter.getLong(ordinal)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}

case ArrayType(et, containsNull) =>
val elementConverter = newConverter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.avro

import scala.collection.JavaConverters._

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.types._
Expand All @@ -42,7 +43,11 @@ object SchemaConverters {
case BYTES => SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => SchemaType(LongType, nullable = false)
case LONG => avroSchema.getLogicalType match {
case _: TimestampMillis | _: TimestampMicros =>
return SchemaType(TimestampType, nullable = false)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use return here?

case _ => SchemaType(LongType, nullable = false)
}
case FIXED => SchemaType(BinaryType, nullable = false)
case ENUM => SchemaType(StringType, nullable = false)

Expand Down Expand Up @@ -103,31 +108,48 @@ object SchemaConverters {
catalystType: DataType,
nullable: Boolean = false,
recordName: String = "topLevelRecord",
prevNameSpace: String = ""): Schema = {
prevNameSpace: String = "",
outputTimestampType: String = "TIMESTAMP_MICROS"): Schema = {
val builder = if (nullable) {
SchemaBuilder.builder().nullable()
} else {
SchemaBuilder.builder()
}

catalystType match {
case BooleanType => builder.booleanType()
case ByteType | ShortType | IntegerType => builder.intType()
case LongType => builder.longType()
case DateType => builder.longType()
case TimestampType => builder.longType()
case TimestampType =>
val timestampType = outputTimestampType match {
case "TIMESTAMP_MILLIS" => LogicalTypes.timestampMillis()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't hardcode the strings, we can write

if (outputTimestampType == AvroOutputTimestampType.TIMESTAMP_MICROS.toString) ...

case "TIMESTAMP_MICROS" => LogicalTypes.timestampMicros()
case other =>
throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
}
if (nullable) {
val avroType = timestampType.addToSchema(SchemaBuilder.builder().longType())
builder.`type`(avroType)
} else {
timestampType.addToSchema(builder.longType())
}
case FloatType => builder.floatType()
case DoubleType => builder.doubleType()
case _: DecimalType | StringType => builder.stringType()
case BinaryType => builder.bytesType()
case ArrayType(et, containsNull) =>
builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace))
builder.array()
.items(toAvroType(et, containsNull, recordName, prevNameSpace, outputTimestampType))
case MapType(StringType, vt, valueContainsNull) =>
builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace))
builder.map()
.values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType))
case st: StructType =>
val nameSpace = s"$prevNameSpace.$recordName"
val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
st.foreach { f =>
val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace)
val fieldAvroType =
toAvroType(f.dataType, f.nullable, f.name, nameSpace, outputTimestampType)
fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
}
fieldsAssembler.endRecord()
Expand Down
Binary file added external/avro/src/test/resources/timestamp.avro
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{StructType, _}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import looks a bit odd :-)


class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val episodesAvro = testFile("episodes.avro")
val testAvro = testFile("test.avro")
val timestampAvro = testFile("timestamp.avro")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the least, we should document how the binary file was generated, or just do a round-trip test: have Spark write the Avro files and then read them back.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema and data are stated in https://github.com/apache/spark/pull/21935/files#diff-9364b0610f92b3cc35a4bc43a80751bfR397
They should be easy to derive from the test cases.
The other test file, episodesAvro, also doesn't document how it was generated.


override protected def beforeAll(): Unit = {
super.beforeAll()
Expand Down Expand Up @@ -331,6 +332,82 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}

// Reads the `timestamp_millis` column from the timestamp fixture file, checks it against the
// expected rows, then writes it back out as Avro and checks the re-read result.
test("Logical type: timestamp_millis") {
  val session = spark
  import session.implicits._

  // Expected rows are the longs 1 and 666 cast to TimestampType.
  val expectedRows = Seq(1L, 666L)
    .toDF("timestamp_millis")
    .select($"timestamp_millis".cast(TimestampType))
    .collect()
  val loaded = spark.read.format("avro").load(timestampAvro).select($"timestamp_millis")

  checkAnswer(loaded, expectedRows)

  withTempPath { outputDir =>
    val outputPath = outputDir.toString
    loaded.write.format("avro").save(outputPath)
    val reloaded = spark.read.format("avro").load(outputPath)
    checkAnswer(reloaded, expectedRows)
  }
}

// Reads the `timestamp_micros` column from the timestamp fixture file, checks it against the
// expected rows, then writes it back out as Avro and checks the re-read result.
test("Logical type: timestamp_micros") {
  val session = spark
  import session.implicits._

  // Expected rows are the longs 2 and 999 cast to TimestampType.
  val expectedRows = Seq(2L, 999L)
    .toDF("timestamp_micros")
    .select($"timestamp_micros".cast(TimestampType))
    .collect()
  val loaded = spark.read.format("avro").load(timestampAvro).select($"timestamp_micros")

  checkAnswer(loaded, expectedRows)

  withTempPath { outputDir =>
    val outputPath = outputDir.toString
    loaded.write.format("avro").save(outputPath)
    val reloaded = spark.read.format("avro").load(outputPath)
    checkAnswer(reloaded, expectedRows)
  }
}

// Writes the fixture data once per supported `outputTimestampType` option value and checks
// that reading the written files yields the same expected rows each time.
test("Logical type: specify different output timestamp types") {
  val session = spark
  import session.implicits._

  val source = spark.read.format("avro").load(timestampAvro)

  val expectedRows = Seq((1L, 2L), (666L, 999L))
    .toDF("timestamp_millis", "timestamp_micros")
    .select($"timestamp_millis".cast(TimestampType), $"timestamp_micros".cast(TimestampType))
    .collect()

  for (outputType <- Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS")) {
    withTempPath { outputDir =>
      val outputPath = outputDir.toString
      source.write.format("avro").option("outputTimestampType", outputType).save(outputPath)
      checkAnswer(spark.read.format("avro").load(outputPath), expectedRows)
    }
  }
}

// Checks that the timestamp fixture file can be read with a user-supplied Avro schema
// (passed through the `avroSchema` option) that declares both logical timestamp types.
test("Logical type: user specified schema") {
val sparkSession = spark
import sparkSession.implicits._

// Expected rows: the long pairs (1, 2) and (666, 999) cast to TimestampType.
val expected = Seq((1L, 2L), (666L, 999L))
.toDF("timestamp_millis", "timestamp_micros")
.select('timestamp_millis.cast(TimestampType), 'timestamp_micros.cast(TimestampType))
.collect()

// Reader schema: a record with one timestamp-millis field and one timestamp-micros field.
val avroSchema = s"""
{
"namespace": "logical",
"type": "record",
"name": "test",
"fields": [
{"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}},
{"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}
]
}
"""
val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro)

checkAnswer(df, expected)
}

test("Array data types") {
withTempPath { dir =>
val testSchema = StructType(Seq(
Expand Down Expand Up @@ -511,7 +588,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

// Timestamps are read back as Timestamp values
val times = spark.read.format("avro").load(avroDir).select("Time").collect()
assert(times.map(_(0)).toSet == Set(666, 777, 42))
assert(times.map(_(0)).toSet ==
Set(new Timestamp(666), new Timestamp(777), new Timestamp(42)))

// DecimalType should be converted to string
val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,21 @@ object SQLConf {
.intConf
.createWithDefault(20)

/**
 * Names of the Avro timestamp types that can be chosen for output. The string forms of these
 * values are the accepted values of the `spark.sql.avro.outputTimestampType` config.
 */
object AvroOutputTimestampType extends Enumeration {
  val TIMESTAMP_MICROS = Value
  val TIMESTAMP_MILLIS = Value
}

// String config selecting the Avro timestamp type used on write. The supplied value is
// upper-cased before being checked against the names in AvroOutputTimestampType; the default
// is TIMESTAMP_MICROS.
val AVRO_OUTPUT_TIMESTAMP_TYPE =
  buildConf("spark.sql.avro.outputTimestampType")
    .doc(
      "Sets which Avro timestamp type to use when Spark writes data to Avro files. " +
        "TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of " +
        "microseconds from the Unix epoch. " +
        "TIMESTAMP_MILLIS is also logical, but with millisecond precision, which means Spark " +
        "has to truncate the microsecond portion of its timestamp value.")
    .stringConf
    .transform(raw => raw.toUpperCase(Locale.ROOT))
    .checkValues(AvroOutputTimestampType.values.map(_.toString))
    .createWithDefault(AvroOutputTimestampType.TIMESTAMP_MICROS.toString)

val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
.doc("Compression codec used in writing of AVRO files. Default codec is snappy.")
.stringConf
Expand Down Expand Up @@ -1835,6 +1850,8 @@ class SQLConf extends Serializable with Logging {

def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE)

// Current value of the SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE entry
// (config key `spark.sql.avro.outputTimestampType`).
def avroOutputTimestampType: String = getConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE)

def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC)

def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL)
Expand Down