diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 981ecae80a72..1ab183fe843f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -69,9 +69,8 @@ case class ExternalRDDScanExec[T]( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val outputDataType = outputObjAttr.dataType rdd.mapPartitionsInternal { iter => - val outputObject = ObjectOperator.wrapObjectToRow(outputDataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) iter.map { value => numOutputRows += 1 outputObject(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index a3be473de7d5..d05113431df4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -191,7 +191,7 @@ case class MapPartitionsExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) func(iter.map(getObject)).map(outputObject) } } @@ -278,10 +278,10 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) + case _ => FunctionUtils.getFunctionOneName(outputObjectType, 
child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) - val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) + val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output) val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) @@ -296,7 +296,7 @@ case class MapElementsExec( child.execute().mapPartitionsInternal { iter => val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) iter.map(row => outputObject(callFunc(getObject(row)))) } } @@ -395,7 +395,7 @@ case class MapGroupsExec( val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) grouped.flatMap { case (key, rowIter) => val result = func( @@ -447,12 +447,8 @@ case class FlatMapGroupsInRExec( outputObjAttr: Attribute, child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def outputPartitioning: Partitioning = child.outputPartitioning - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) - override def requiredChildDistribution: Seq[Distribution] = if (groupingAttributes.isEmpty) { AllTuples :: Nil @@ -475,7 +471,7 @@ case class FlatMapGroupsInRExec( val grouped = GroupedIterator(iter, groupingAttributes, child.output) val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val 
outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]]( func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars, isDataFrame = true, colNames = inputSchema.fieldNames, @@ -608,7 +604,7 @@ case class CoGroupExec( val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup) val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr) val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 6b6eb78404e3..fe91d2491222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.util.CompletionIterator /** - * Physical operator for executing `FlatMapGroupsWithState.` + * Physical operator for executing `FlatMapGroupsWithState` * * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. 
* @param groupingAttributes used to group the data * @param dataAttributes used to read the data - * @param outputObjAttr used to define the output object + * @param outputObjAttr defines the output object * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` * @param timeoutConf used to timeout groups that have not received data in a while @@ -154,7 +154,7 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) private val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows")