SqlTypedRecord.scala
@@ -17,27 +17,59 @@

 package org.apache.spark.sql.hudi.command.payload
 
+import com.google.common.cache.CacheBuilder
 import org.apache.avro.Schema
 import org.apache.avro.generic.IndexedRecord
-import org.apache.hudi.{AvroConversionUtils, SparkAdapterSupport}
+import org.apache.hudi.HoodieSparkUtils.sparkAdapter
+import org.apache.hudi.AvroConversionUtils
+import org.apache.spark.sql.avro.HoodieAvroDeserializer
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.hudi.command.payload.SqlTypedRecord.{getAvroDeserializer, getSqlType}
+import org.apache.spark.sql.types.StructType
+
+import java.util.concurrent.Callable
 
 /**
  * A sql typed record which will convert the avro field to sql typed value.
  */
-class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord with SparkAdapterSupport {
+class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {
 
-  private lazy val sqlType = AvroConversionUtils.convertAvroSchemaToStructType(getSchema)
-  private lazy val avroDeserializer = sparkAdapter.createAvroDeserializer(record.getSchema, sqlType)
-  private lazy val sqlRow = avroDeserializer.deserialize(record).get.asInstanceOf[InternalRow]
+  private lazy val sqlRow = getAvroDeserializer(getSchema).deserialize(record).get.asInstanceOf[InternalRow]
 
   override def put(i: Int, v: Any): Unit = {
     record.put(i, v)
   }
 
   override def get(i: Int): AnyRef = {
-    sqlRow.get(i, sqlType(i).dataType)
+    sqlRow.get(i, getSqlType(getSchema)(i).dataType)
   }
 
   override def getSchema: Schema = record.getSchema
 }
+
+object SqlTypedRecord {
+
+  private val sqlTypeCache = CacheBuilder.newBuilder().build[Schema, StructType]()
+
+  private val avroDeserializerCache = CacheBuilder.newBuilder().build[Schema, HoodieAvroDeserializer]()
+
+  def getSqlType(schema: Schema): StructType = {
+    sqlTypeCache.get(schema, new Callable[StructType] {
+      override def call(): StructType = {
+        val structType = AvroConversionUtils.convertAvroSchemaToStructType(schema)
+        sqlTypeCache.put(schema, structType)
+        structType
+      }
+    })
+  }
+
+  def getAvroDeserializer(schema: Schema): HoodieAvroDeserializer = {
+    avroDeserializerCache.get(schema, new Callable[HoodieAvroDeserializer] {
+      override def call(): HoodieAvroDeserializer = {
+        val deserializer = sparkAdapter.createAvroDeserializer(schema, getSqlType(schema))
+        avroDeserializerCache.put(schema, deserializer)
+        deserializer
+      }
+    })
+  }
+}
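
Worth noting about the caching above: Guava's `Cache.get(key, loader)` already computes and inserts the value on a miss, so the explicit `sqlTypeCache.put`/`avroDeserializerCache.put` calls inside the `Callable`s are redundant (though harmless). Keying on Avro `Schema` works because `Schema` implements content-based `equals`/`hashCode`. Here is a minimal, self-contained sketch of the same get-with-loader pattern; `String`/`Integer` are illustrative stand-ins for `Schema`/`StructType`, not types from the PR:

```scala
import com.google.common.cache.CacheBuilder

import java.util.concurrent.Callable

object GuavaCacheSketch {
  // String/Integer stand in for Schema/StructType in the PR.
  private val cache = CacheBuilder.newBuilder().build[String, Integer]()

  def getOrCompute(key: String): Integer = {
    // Cache.get(key, loader) runs the loader only on a miss and stores the
    // result itself, so no explicit put is needed inside the Callable.
    cache.get(key, new Callable[Integer] {
      override def call(): Integer = Integer.valueOf(key.length)
    })
  }

  def main(args: Array[String]): Unit = {
    println(getOrCompute("hoodie")) // miss: loader computes 6
    println(getOrCompute("hoodie")) // hit: cached 6, loader not invoked
  }
}
```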
ExpressionPayload.scala
@@ -32,7 +32,7 @@ import org.apache.hudi.sql.IExpressionEvaluator
 import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.hudi.SerDeUtils
-import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
+import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.{getEvaluator, getMergedSchema}
 import org.apache.spark.sql.types.{StructField, StructType}
 
 import java.util.concurrent.Callable
@@ -228,9 +228,7 @@ class ExpressionPayload(record: GenericRecord,
    */
   private def joinRecord(sourceRecord: IndexedRecord, targetRecord: IndexedRecord): IndexedRecord = {
     val leftSchema = sourceRecord.getSchema
-    // the targetRecord is load from the disk, it contains the meta fields, so we remove it here
-    val rightSchema = HoodieAvroUtils.removeMetadataFields(targetRecord.getSchema)
-    val joinSchema = mergeSchema(leftSchema, rightSchema)
+    val joinSchema = getMergedSchema(leftSchema, targetRecord.getSchema)
 
     val values = new ArrayBuffer[AnyRef]()
     for (i <- 0 until joinSchema.getFields.size()) {
@@ -244,17 +242,6 @@
     convertToRecord(values.toArray, joinSchema)
   }
 
-  private def mergeSchema(a: Schema, b: Schema): Schema = {
-    val mergedFields =
-      a.getFields.asScala.map(field =>
-        new Schema.Field("a_" + field.name,
-          field.schema, field.doc, field.defaultVal, field.order)) ++
-      b.getFields.asScala.map(field =>
-        new Schema.Field("b_" + field.name,
-          field.schema, field.doc, field.defaultVal, field.order))
-    Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
-  }
-
   private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): GenericRecord = {
     try evaluator.eval(sqlTypedRecord) catch {
       case e: Throwable =>
@@ -318,5 +305,30 @@ object ExpressionPayload {
       }
     })
   }
+
+  private val mergedSchemaCache = CacheBuilder.newBuilder().build[TupleSchema, Schema]()
+
+  def getMergedSchema(source: Schema, target: Schema): Schema = {
+    mergedSchemaCache.get(TupleSchema(source, target), new Callable[Schema] {
+      override def call(): Schema = {
+        // the target record is loaded from disk and still carries the meta fields, so remove them here
+        val rightSchema = HoodieAvroUtils.removeMetadataFields(target)
+        mergeSchema(source, rightSchema)
+      }
+    })
+  }
+
+  def mergeSchema(a: Schema, b: Schema): Schema = {
+    val mergedFields =
+      a.getFields.asScala.map(field =>
+        new Schema.Field("a_" + field.name,
+          field.schema, field.doc, field.defaultVal, field.order)) ++
+      b.getFields.asScala.map(field =>
+        new Schema.Field("b_" + field.name,
+          field.schema, field.doc, field.defaultVal, field.order))
+    Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
+  }
+
+  case class TupleSchema(first: Schema, second: Schema)
 }
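
To make the join-schema shape concrete: `mergeSchema` prefixes every source field with `a_` and every target field with `b_`, so columns present on both sides cannot collide, and `getMergedSchema` memoizes the result per `(source, target)` pair via `TupleSchema` (again relying on Avro `Schema`'s content-based equality). Below is a self-contained sketch that replicates the prefixing merge on two toy schemas; the `source`/`target` record shapes are invented for illustration, and it assumes Avro 1.9+ for the `Schema.Field` constructor with an `Object`-typed default:

```scala
import org.apache.avro.{Schema, SchemaBuilder}

import scala.collection.JavaConverters._

object MergeSchemaSketch {
  // Same prefixing merge as ExpressionPayload.mergeSchema: fields of `a`
  // become a_<name>, fields of `b` become b_<name>.
  def mergeSchema(a: Schema, b: Schema): Schema = {
    val mergedFields =
      a.getFields.asScala.map(f =>
        new Schema.Field("a_" + f.name, f.schema, f.doc, f.defaultVal, f.order)) ++
      b.getFields.asScala.map(f =>
        new Schema.Field("b_" + f.name, f.schema, f.doc, f.defaultVal, f.order))
    Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava)
  }

  def main(args: Array[String]): Unit = {
    // Toy stand-ins for the source record and the meta-field-stripped target.
    val source = SchemaBuilder.record("source").fields()
      .requiredString("id").requiredInt("ts").endRecord()
    val target = SchemaBuilder.record("target").fields()
      .requiredString("id").requiredString("name").endRecord()

    val joined = mergeSchema(source, target)
    // Prints: a_id, a_ts, b_id, b_name -- the shared "id" column no longer collides.
    println(joined.getFields.asScala.map(_.name).mkString(", "))
  }
}
```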