Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ case class ExpressionEncoder[T](
private lazy val inputRow = new GenericInternalRow(1)

@transient
private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
private lazy val constructProjection = SafeProjection.create(deserializer :: Nil)

/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._


/**
* An interpreted version of a safe projection.
*
* @param expressions that produces the resulting fields. These expressions must be bound
* to a schema.
*/
class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {

private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))

private[this] val exprsWithWriters = expressions.zipWithIndex.filter {
case (NoOp, _) => false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does SafeProjection need to handle NoOp? It's only used with MutableProjection in aggregate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC the input expressions in UnsafeProjection possibly have NoOps passed from aggregate expressions? So, IIUC GenerateSafeProjection handles NoOps here:


I'm not 100% sure though...

case _ => true
}.map { case (e, i) =>
val converter = generateSafeValueConverter(e.dataType)
val writer = InternalRow.getWriter(i, e.dataType)
val f = if (!e.nullable) {
(v: Any) => writer(mutableRow, converter(v))
} else {
(v: Any) => {
if (v == null) {
mutableRow.setNullAt(i)
} else {
writer(mutableRow, converter(v))
}
}
}
(e, f)
}

/**
 * Builds a function that copies a value of type `dt` into a "safe" container:
 * arrays become [[GenericArrayData]], structs become [[GenericInternalRow]] and maps become
 * [[ArrayBasedMapData]], with nested values converted recursively. User-defined types are
 * converted according to their underlying `sqlType`. Values of any other type are returned
 * unchanged.
 *
 * The returned function passes `null` through unchanged. The nullability flags carried by
 * [[ArrayType]] and [[MapType]] are not consulted here, so array elements, struct fields and
 * map values reaching a nested converter may be null and must not be dereferenced.
 *
 * @param dt the data type of the values the returned function will be applied to
 * @return a converter from a (possibly null) value of type `dt` to its safe representation
 */
private def generateSafeValueConverter(dt: DataType): Any => Any = dt match {
  case ArrayType(elemType, _) =>
    val elementConverter = generateSafeValueConverter(elemType)
    nullSafeConverter { v =>
      val arrayValue = v.asInstanceOf[ArrayData]
      val result = new Array[Any](arrayValue.numElements())
      arrayValue.foreach(elemType, (i, e) => {
        result(i) = elementConverter(e)
      })
      new GenericArrayData(result)
    }

  case st: StructType =>
    val fieldTypes = st.fields.map(_.dataType)
    val fieldConverters = fieldTypes.map(generateSafeValueConverter)
    nullSafeConverter { v =>
      val row = v.asInstanceOf[InternalRow]
      val ar = new Array[Any](row.numFields)
      var idx = 0
      while (idx < row.numFields) {
        ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
        idx += 1
      }
      new GenericInternalRow(ar)
    }

  case MapType(keyType, valueType, _) =>
    // Plain vals, as in the array case: the converters are always needed once the returned
    // function is invoked, so laziness buys nothing here.
    val keyConverter = generateSafeValueConverter(keyType)
    val valueConverter = generateSafeValueConverter(valueType)
    nullSafeConverter { v =>
      val mapValue = v.asInstanceOf[MapData]
      val keys = mapValue.keyArray().toArray[Any](keyType)
      val values = mapValue.valueArray().toArray[Any](valueType)
      val convertedKeys = keys.map(keyConverter)
      val convertedValues = values.map(valueConverter)
      ArrayBasedMapData(convertedKeys, convertedValues)
    }

  case udt: UserDefinedType[_] =>
    generateSafeValueConverter(udt.sqlType)

  case _ => identity
}

/**
 * Wraps `convert` so that a null input yields null instead of being handed to `convert`,
 * which would otherwise dereference it.
 */
private[this] def nullSafeConverter(convert: Any => Any): Any => Any = {
  v => if (v == null) null else convert(v)
}

/**
 * Evaluates every retained expression against `row` and hands each result to the paired
 * writer function, in order.
 *
 * @param row the input row the expressions are evaluated on
 * @return `mutableRow`; the same instance is returned on every call
 */
override def apply(row: InternalRow): InternalRow = {
  var pos = 0
  while (pos < exprsWithWriters.length) {
    val entry = exprsWithWriters(pos)
    val evaluated = entry._1.eval(row)
    entry._2(evaluated)
    pos += 1
  }
  mutableRow
}
}

/**
* Helper functions for creating an [[InterpretedSafeProjection]].
*/
object InterpretedSafeProjection {

  /**
   * Returns a [[Projection]] backed by an [[InterpretedSafeProjection]] for the given
   * sequence of bound expressions.
   */
  def createProjection(exprs: Seq[Expression]): Projection = {
    // Stateful expressions must not be reused, so each one found anywhere in an expression
    // tree is replaced by a fresh copy before the projection is built.
    def withFreshStatefulCopies(expr: Expression): Expression = expr.transform {
      case stateful: Stateful => stateful.freshCopy()
    }

    new InterpretedSafeProjection(exprs.map(withFreshStatefulCopies))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,40 @@ object UnsafeProjection
/**
* A projection that could turn UnsafeRow into GenericInternalRow
*/
object FromUnsafeProjection {
object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] {

override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = {
GenerateSafeProjection.generate(in)
}

override protected def createInterpretedObject(in: Seq[Expression]): Projection = {
InterpretedSafeProjection.createProjection(in)
}

/**
* Returns a Projection for given StructType.
* Returns a SafeProjection for given StructType.
*/
def apply(schema: StructType): Projection = {
apply(schema.fields.map(_.dataType))
def create(schema: StructType): Projection = create(schema.fields.map(_.dataType))

/**
* Returns a SafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): Projection = {
createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
}

/**
* Returns an UnsafeProjection for given Array of DataTypes.
* Returns a SafeProjection for given sequence of Expressions (bounded).
*/
def apply(fields: Seq[DataType]): Projection = {
create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
def create(exprs: Seq[Expression]): Projection = {
createObject(exprs)
}

/**
* Returns a Projection for given sequence of Expressions (bounded).
* Returns a SafeProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
*/
private def create(exprs: Seq[Expression]): Projection = {
GenerateSafeProjection.generate(exprs)
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
create(toBoundExprs(exprs, inputSchema))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
UTF8String.fromString("c"))
assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)

val fromUnsafe = FromUnsafeProjection(schema)
val fromUnsafe = SafeProjection.create(schema)
val internalRow2 = fromUnsafe(unsafeRow)
assert(internalRow === internalRow2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,19 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT
assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
}
}

test("SPARK-25374 Correctly handles NoOp in SafeProjection") {
val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp)
val input = InternalRow.fromSeq(1 :: 1 :: Nil)
val expected = 2 :: null :: Nil
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use testBothCodegenAndInterpreted?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, this is the code style in this test suite

val proj = SafeProjection.createObject(exprs)
assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
}

withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
val proj = SafeProjection.createObject(exprs)
assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
expression)
plan.initialize(0)
actual = FromUnsafeProjection(expression.dataType :: Nil)(
plan(inputRow)).get(0, expression.dataType)
val ref = new BoundReference(0, expression.dataType, nullable = true)
actual = GenerateSafeProjection.generate(ref :: Nil)(plan(inputRow)).get(0, expression.dataType)
assert(checkResult(actual, expected, expression))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length)
val proj = createMutableProjection(fixedLengthTypes)
val projUnsafeRow = proj.target(unsafeBuffer)(inputRow)
assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow)
assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow)
}

testBothCodegenAndInterpreted("variable-length types") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types.{IntegerType, LongType, _}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase
with ExpressionEvalHelper {
Expand Down Expand Up @@ -535,4 +535,91 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
}

// Round-trips rows of simple, array, struct, map and UDT columns through an UnsafeProjection
// and back through a SafeProjection, and checks that the original row comes back.
testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") {
  // Converts `inputRow` to its unsafe form and then back into a safe row.
  def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = {
    val unsafeProj = UnsafeProjection.create(fields)
    val unsafeRow = unsafeProj(inputRow)
    val safeProj = SafeProjection.create(fields)
    safeProj(unsafeRow)
  }

  // Simple tests
  val inputRow = InternalRow.fromSeq(Seq(
    false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"),
    Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2)
  ))
  val fields1 = Array(
    BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
    DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType,
    BinaryType)

  assert(convertBackToInternalRow(inputRow, fields1) === inputRow)

  // Array tests
  val arrayRow = InternalRow.fromSeq(Seq(
    createArray(1, 2, 3),
    createArray(
      createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*),
      createArray(Seq("d").map(UTF8String.fromString): _*))
  ))
  val fields2 = Array[DataType](
    ArrayType(IntegerType),
    ArrayType(ArrayType(StringType)))

  assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow)

  // Struct tests
  val structRow = InternalRow.fromSeq(Seq(
    InternalRow.fromSeq(Seq[Any](1, 4.0)),
    InternalRow.fromSeq(Seq(
      UTF8String.fromString("test"),
      InternalRow.fromSeq(Seq(
        1,
        createArray(Seq("2", "3").map(UTF8String.fromString): _*)
      ))
    ))
  ))
  val fields3 = Array[DataType](
    StructType(
      StructField("c0", IntegerType) ::
      StructField("c1", DoubleType) ::
      Nil),
    StructType(
      StructField("c2", StringType) ::
      StructField("c3", StructType(
        StructField("c4", IntegerType) ::
        StructField("c5", ArrayType(StringType)) ::
        Nil)) ::
      Nil))

  assert(convertBackToInternalRow(structRow, fields3) === structRow)

  // Map tests
  val mapRow = InternalRow.fromSeq(Seq(
    createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2),
    createMap(
      createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*),
      createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*)
    )(
      createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort),
      createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort)
    )))
  val fields4 = Array[DataType](
    MapType(StringType, IntegerType),
    MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType)))

  val mapResultRow = convertBackToInternalRow(mapRow, fields4)
  val mapExpectedRow = mapRow
  // checkResult only reports the comparison outcome as a Boolean, so it has to be asserted;
  // otherwise a mismatch in the map round trip would go unnoticed. The actual row is passed
  // first and the expected row second.
  assert(checkResult(mapResultRow, mapExpectedRow,
    exprDataType = StructType(fields4.zipWithIndex.map(f => StructField(s"c${f._2}", f._1))),
    exprNullable = false))

  // UDT tests
  val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
  val udt = new TestUDT.MyDenseVectorUDT()
  val udtRow = InternalRow.fromSeq(Seq(udt.serialize(vector)))
  val fields5 = Array[DataType](udt)
  assert(convertBackToInternalRow(udtRow, fields5) === udtRow)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,24 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, SafeProjection}

/**
* Evaluator for a [[DeclarativeAggregate]].
*/
case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) {

lazy val initializer = GenerateSafeProjection.generate(function.initialValues)
lazy val initializer = SafeProjection.create(function.initialValues)

lazy val updater = GenerateSafeProjection.generate(
lazy val updater = SafeProjection.create(
function.updateExpressions,
function.aggBufferAttributes ++ input)

lazy val merger = GenerateSafeProjection.generate(
lazy val merger = SafeProjection.create(
function.mergeExpressions,
function.aggBufferAttributes ++ function.inputAggBufferAttributes)

lazy val evaluator = GenerateSafeProjection.generate(
lazy val evaluator = SafeProjection.create(
function.evaluateExpression :: Nil,
function.aggBufferAttributes)

Expand Down
Loading