[SPARK-25374][SQL] SafeProjection supports fallback to an interpreted mode #22468
Changes from 4 commits
New file added by this PR (@@ -0,0 +1,125 @@):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
 * An interpreted version of a safe projection.
 *
 * @param expressions that produce the resulting fields. These expressions must be bound
 *                    to a schema.
 */
class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {

  private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))

  private[this] val exprsWithWriters = expressions.zipWithIndex.filter {
    case (NoOp, _) => false
    case _ => true
  }.map { case (e, i) =>
    val converter = generateSafeValueConverter(e.dataType)
    val writer = InternalRow.getWriter(i, e.dataType)
    val f = if (!e.nullable) {
      (v: Any) => writer(mutableRow, converter(v))
    } else {
      (v: Any) => {
        if (v == null) {
          mutableRow.setNullAt(i)
        } else {
          writer(mutableRow, converter(v))
        }
      }
    }
    (e, f)
  }

  private def generateSafeValueConverter(dt: DataType): Any => Any = dt match {
    case ArrayType(elemType, _) =>
      val elementConverter = generateSafeValueConverter(elemType)
      v => {
        val arrayValue = v.asInstanceOf[ArrayData]
        val result = new Array[Any](arrayValue.numElements())
        arrayValue.foreach(elemType, (i, e) => {
          result(i) = elementConverter(e)
        })
        new GenericArrayData(result)
      }

    case st: StructType =>
      val fieldTypes = st.fields.map(_.dataType)
      val fieldConverters = fieldTypes.map(generateSafeValueConverter)
      v => {
        val row = v.asInstanceOf[InternalRow]
        val ar = new Array[Any](row.numFields)
        var idx = 0
        while (idx < row.numFields) {
          ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
          idx += 1
        }
        new GenericInternalRow(ar)
      }

    case MapType(keyType, valueType, _) =>
      lazy val keyConverter = generateSafeValueConverter(keyType)
      lazy val valueConverter = generateSafeValueConverter(valueType)
      v => {
        val mapValue = v.asInstanceOf[MapData]
        val keys = mapValue.keyArray().toArray[Any](keyType)
        val values = mapValue.valueArray().toArray[Any](valueType)
        val convertedKeys = keys.map(keyConverter)
        val convertedValues = values.map(valueConverter)
        ArrayBasedMapData(convertedKeys, convertedValues)
      }

    case udt: UserDefinedType[_] =>
      generateSafeValueConverter(udt.sqlType)

    case _ => identity
  }

  override def apply(row: InternalRow): InternalRow = {
    var i = 0
    while (i < exprsWithWriters.length) {
      val (expr, writer) = exprsWithWriters(i)
      writer(expr.eval(row))
      i += 1
    }
    mutableRow
  }
}

/**
 * Helper functions for creating an [[InterpretedSafeProjection]].
 */
object InterpretedSafeProjection {

  /**
   * Returns a [[SafeProjection]] for given sequence of bound Expressions.
   */
  def createProjection(exprs: Seq[Expression]): Projection = {
    // We need to make sure that we do not reuse stateful expressions.
    val cleanedExpressions = exprs.map(_.transform {
      case s: Stateful => s.freshCopy()
    })
    new InterpretedSafeProjection(cleanedExpressions)
  }
}
```
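For readers skimming the diff, here is a minimal usage sketch of the interpreted path above (not part of this PR; the object name, column types, and sample values are illustrative). The projection reuses a single mutable row, so the result is copied before inspection:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedSafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object InterpretedSafeProjectionSketch {
  def main(args: Array[String]): Unit = {
    // Bind one reference per input column, mirroring what SafeProjection.create(fields) does.
    val exprs = Seq(
      BoundReference(0, IntegerType, nullable = true),
      BoundReference(1, StringType, nullable = true))
    val proj = InterpretedSafeProjection.createProjection(exprs)

    val result = proj(InternalRow.fromSeq(Seq(2, UTF8String.fromString("a"))))
    // The same mutable row is returned on every call, so copy it to keep the values.
    val safeRow = result.copy()
    assert(safeRow.getInt(0) == 2)
    assert(safeRow.getUTF8String(1) == UTF8String.fromString("a"))
  }
}
```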
@@ -166,29 +166,40 @@ object UnsafeProjection

```diff
   }
 }

-/**
- * A projection that could turn UnsafeRow into GenericInternalRow
- */
-object FromUnsafeProjection {
+object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] {
+
+  override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = {
+    GenerateSafeProjection.generate(in)
+  }
+
+  override protected def createInterpretedObject(in: Seq[Expression]): Projection = {
+    InterpretedSafeProjection.createProjection(in)
+  }

   /**
-   * Returns a Projection for given StructType.
+   * Returns a SafeProjection for given StructType.
    */
-  def apply(schema: StructType): Projection = {
-    apply(schema.fields.map(_.dataType))
-  }
+  def create(schema: StructType): Projection = create(schema.fields.map(_.dataType))
+
+  /**
+   * Returns a SafeProjection for given Array of DataTypes.
+   */
+  def create(fields: Array[DataType]): Projection = {
+    createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+  }

   /**
-   * Returns an UnsafeProjection for given Array of DataTypes.
+   * Returns a SafeProjection for given sequence of Expressions (bounded).
    */
-  def apply(fields: Seq[DataType]): Projection = {
-    create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+  def create(exprs: Seq[Expression]): Projection = {
+    createObject(exprs)
   }

   /**
-   * Returns a Projection for given sequence of Expressions (bounded).
+   * Returns a SafeProjection for given sequence of Expressions, which will be bound to
+   * `inputSchema`.
    */
-  private def create(exprs: Seq[Expression]): Projection = {
-    GenerateSafeProjection.generate(exprs)
+  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+    create(toBoundExprs(exprs, inputSchema))
   }
 }
```

Contributor (on the removed doc comment `A projection that could turn UnsafeRow into GenericInternalRow`): can we keep this comment?
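A hedged usage sketch of the new `SafeProjection` entry points (again not part of the diff; names and values other than `SafeProjection.create` and `UnsafeProjection.create` are made up). Whether the code-generated or the interpreted implementation backs the returned projection is decided by `CodeGeneratorWithInterpretedFallback`, driven by `SQLConf.CODEGEN_FACTORY_MODE`, as the tests below exercise:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeProjection}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}

object SafeProjectionRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val fields: Array[DataType] = Array(IntegerType, DoubleType)

    // Forward: generic InternalRow -> UnsafeRow.
    val unsafeRow = UnsafeProjection.create(fields)(InternalRow.fromSeq(Seq(1, 2.0)))

    // Back: UnsafeRow -> safe (generic) representation via the new factory.
    val safeRow = SafeProjection.create(fields)(unsafeRow)
    assert(safeRow.getInt(0) == 1)
    assert(safeRow.getDouble(1) == 2.0)
  }
}
```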
@@ -106,4 +106,19 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT

```diff
       assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
     }
   }
+
+  test("SPARK-25374 Correctly handles NoOp in SafeProjection") {
+    val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp)
+    val input = InternalRow.fromSeq(1 :: 1 :: Nil)
+    val expected = 2 :: null :: Nil
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
+      val proj = SafeProjection.createObject(exprs)
+      assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
+    }
+
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
+      val proj = SafeProjection.createObject(exprs)
+      assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
+    }
+  }
 }
```

Contributor (on the first `withSQLConf` line): can we use …
Contributor: nvm, this is the code style in this test suite
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase

```diff
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types.{IntegerType, LongType, _}
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

 class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase
   with ExpressionEvalHelper {
```

@@ -535,4 +535,100 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB

```diff
     assert(unsafeRow.getSizeInBytes ==
       8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
   }
+
+  testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") {
+    def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = {
+      val unsafeProj = UnsafeProjection.create(fields)
+      val unsafeRow = unsafeProj(inputRow)
+      val safeProj = SafeProjection.create(fields)
+      safeProj(unsafeRow)
+    }
+
+    // Simple tests
+    val inputRow = InternalRow.fromSeq(Seq(
+      false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"),
+      Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2)
+    ))
+    val fields1 = Array(
+      BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+      DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType,
+      BinaryType)
+
+    assert(convertBackToInternalRow(inputRow, fields1) === inputRow)
+
+    // Array tests
+    val arrayRow = InternalRow.fromSeq(Seq(
+      createArray(1, 2, 3),
+      createArray(
+        createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*),
+        createArray(Seq("d").map(UTF8String.fromString): _*))
+    ))
+    val fields2 = Array[DataType](
+      ArrayType(IntegerType),
+      ArrayType(ArrayType(StringType)))
+
+    assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow)
+
+    // Struct tests
+    val structRow = InternalRow.fromSeq(Seq(
+      InternalRow.fromSeq(Seq[Any](1, 4.0)),
+      InternalRow.fromSeq(Seq(
+        UTF8String.fromString("test"),
+        InternalRow.fromSeq(Seq(
+          1,
+          createArray(Seq("2", "3").map(UTF8String.fromString): _*)
+        ))
+      ))
+    ))
+    val fields3 = Array[DataType](
+      StructType(
+        StructField("c0", IntegerType) ::
+        StructField("c1", DoubleType) ::
+        Nil),
+      StructType(
+        StructField("c2", StringType) ::
+        StructField("c3", StructType(
+          StructField("c4", IntegerType) ::
+          StructField("c5", ArrayType(StringType)) ::
+          Nil)) ::
+        Nil))
+
+    assert(convertBackToInternalRow(structRow, fields3) === structRow)
+
+    // Map tests
+    val mapRow = InternalRow.fromSeq(Seq(
+      createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2),
+      createMap(
+        createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*),
+        createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*)
+      )(
+        createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort),
+        createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort)
+      )))
+    val fields4 = Array[DataType](
+      MapType(StringType, IntegerType),
+      MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType)))
+
+    // Since `ArrayBasedMapData` does not override `equals` and `hashCode`,
+    // we need to take care of it to compare rows.
+    def toComparable(d: Any): Any = d match {
+      case ar: GenericArrayData =>
+        ar.array.map(toComparable).toSeq
+      case map: ArrayBasedMapData =>
+        val keys = map.keyArray.array.map(toComparable).toSeq
+        val values = map.valueArray.array.map(toComparable).toSeq
+        (keys, values)
+      case o => o
+    }
+    val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4)
+    val mapExpectedRow = mapRow.toSeq(fields4)
+    assert(mapResultRow.map(toComparable) === mapExpectedRow.map(toComparable))
+
+    // UDT tests
+    val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
+    val udt = new TestUDT.MyDenseVectorUDT()
+    val udtRow = InternalRow.fromSeq(Seq(udt.serialize(vector)))
+    val fields5 = Array[DataType](udt)
+    assert(convertBackToInternalRow(udtRow, fields5) === udtRow)
+  }
 }
```
Reviewer: does `SafeProjection` need to handle `NoOp`? It's only used with `MutableProjection` in aggregate.

Reply: IIUC the input expressions in `UnsafeProjection` possibly have `NoOp`s passed from aggregate expressions? So, IIUC `GenerateSafeProjection` handles `NoOp`s here:
spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala (Line 153 in 3b45567)
I'm not 100% sure though...
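As a side note on the `NoOp` question above, a tiny sketch (my own illustration, not from the PR) of why projections special-case it: `NoOp` is an unevaluable placeholder used by aggregate `MutableProjection`s, so calling `eval` on it is expected to fail, which is why both `GenerateSafeProjection` and the new interpreted version skip those slots and leave them null:

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp

object NoOpEvalSketch {
  def main(args: Array[String]): Unit = {
    // NoOp cannot be evaluated directly; projections filter it out instead.
    val threw =
      try { NoOp.eval(); false }
      catch { case scala.util.control.NonFatal(_) => true }
    println(s"NoOp.eval() failed as expected: $threw")
  }
}
```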