@@ -69,9 +69,8 @@ case class ExternalRDDScanExec[T](

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val outputDataType = outputObjAttr.dataType
     rdd.mapPartitionsInternal { iter =>
-      val outputObject = ObjectOperator.wrapObjectToRow(outputDataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
       iter.map { value =>
         numOutputRows += 1
         outputObject(value)
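Every hunk in this diff makes the same substitution: the repeated `outputObjAttr.dataType` lookup is replaced by a single `outputObjectType` accessor. The diff itself never shows where that accessor is defined, but since these operators all mix in `ObjectProducerExec` (visible in the FlatMapGroupsInRExec hunk below), a minimal sketch of the helper it presumably adds to that trait, reconstructed from the diff rather than copied from the Spark source, would be:

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.types.DataType

// Sketch of the shared trait this diff assumes (reconstructed, not verbatim):
trait ObjectProducerExec extends SparkPlan {

  // The attribute holding the single object-typed column this operator produces.
  def outputObjAttr: Attribute

  override def output: Seq[Attribute] = outputObjAttr :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)

  // The one-stop accessor every call site in this diff switches to.
  def outputObjectType: DataType = outputObjAttr.dataType
}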
@@ -191,7 +191,7 @@ case class MapPartitionsExec(
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
       val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
       func(iter.map(getObject)).map(outputObject)
     }
   }
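MapPartitionsExec is the clearest illustration of the object-operator pattern: unwrap each incoming row to a JVM object, run the user function over the whole partition, and wrap each result back into a row. A self-contained toy model of that contract, with hypothetical stand-ins for Spark's InternalRow and ObjectOperator helpers:

// Toy model of the unwrap -> func -> wrap pipeline (Row1, unwrap, wrap are
// hypothetical stand-ins, not Spark's actual types):
final case class Row1(value: Any) // a row with a single object-typed column

val unwrap: Row1 => Any = _.value  // analogue of ObjectOperator.unwrapObjectFromRow
val wrap: Any => Row1 = Row1(_)    // analogue of ObjectOperator.wrapObjectToRow

def mapPartition(iter: Iterator[Row1], func: Iterator[Any] => Iterator[Any]): Iterator[Row1] =
  func(iter.map(unwrap)).map(wrap) // mirrors the body of doExecute above

For example, `mapPartition(Iterator(Row1(1), Row1(2)), _.map { case i: Int => i * 2 }).toList` yields `List(Row1(2), Row1(4))`.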
@@ -278,10 +278,10 @@ case class MapElementsExec(
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val (funcClass, methodName) = func match {
       case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
-      case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
+      case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType)
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
-    val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
+    val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output)

     val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)

@@ -296,7 +296,7 @@ case class MapElementsExec(

     child.execute().mapPartitionsInternal { iter =>
       val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
       iter.map(row => outputObject(callFunc(getObject(row))))
     }
   }
@@ -395,7 +395,7 @@ case class MapGroupsExec(

      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
-     val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+     val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)

      grouped.flatMap { case (key, rowIter) =>
        val result = func(
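MapGroupsExec relies on GroupedIterator to chunk a key-sorted partition into (key, rows-for-key) pairs before applying `func`. An eager, collection-based model of that chunking (illustrative only; the real iterator is lazy and assumes its input is already sorted by the grouping key):

// Eager model of sort-based grouping: adjacent rows with equal keys
// form one group, so the input must be key-sorted.
def grouped[K, V](sortedRows: Seq[(K, V)]): List[(K, List[V])] =
  sortedRows.foldRight(List.empty[(K, List[V])]) {
    case ((k, v), (k2, vs) :: rest) if k == k2 => (k, v :: vs) :: rest
    case ((k, v), acc) => (k, List(v)) :: acc
  }

For example, `grouped(Seq(1 -> "a", 1 -> "b", 2 -> "c"))` yields `List((1, List("a", "b")), (2, List("c")))`.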
@@ -447,12 +447,8 @@ case class FlatMapGroupsInRExec(
     outputObjAttr: Attribute,
     child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-
   override def outputPartitioning: Partitioning = child.outputPartitioning

-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
   override def requiredChildDistribution: Seq[Distribution] =
     if (groupingAttributes.isEmpty) {
       AllTuples :: Nil
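The two deleted overrides match exactly the definitions that `ObjectProducerExec` supplies in the sketch after the first hunk, so dropping them removes duplication rather than behavior; `outputPartitioning` stays because the trait presumably provides no default for it.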
@@ -475,7 +471,7 @@ case class FlatMapGroupsInRExec(
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
       val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
-      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
       val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]](
         func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
         isDataFrame = true, colNames = inputSchema.fieldNames,
@@ -608,7 +604,7 @@ case class CoGroupExec(
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
       val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
       val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
-      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)

      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
        case (key, leftResult, rightResult) =>
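CoGroupExec aligns the two grouped children by key before handing both per-key iterators to the user function. A plain-Scala model of what that alignment computes (illustrative only; Spark's CoGroupedIterator walks two key-sorted iterators lazily instead of materializing maps):

// Eager, collection-based model of co-grouping (key order not preserved):
def coGroup[K, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, Seq[A], Seq[B])] = {
  val l = left.groupBy(_._1)
  val r = right.groupBy(_._1)
  // Every key seen on either side produces one output tuple, possibly with
  // an empty group on the side that lacks the key.
  (l.keySet ++ r.keySet).toSeq.map { key =>
    (key,
      l.getOrElse(key, Seq.empty).map(_._2),
      r.getOrElse(key, Seq.empty).map(_._2))
  }
}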
@@ -28,14 +28,14 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.util.CompletionIterator

 /**
- * Physical operator for executing `FlatMapGroupsWithState.`
+ * Physical operator for executing `FlatMapGroupsWithState`
  *
  * @param func function called on each group
  * @param keyDeserializer used to extract the key object for each group.
  * @param valueDeserializer used to extract the items in the iterator from an input row.
  * @param groupingAttributes used to group the data
  * @param dataAttributes used to read the data
- * @param outputObjAttr used to define the output object
+ * @param outputObjAttr Defines the output object
  * @param stateEncoder used to serialize/deserialize state before calling `func`
  * @param outputMode the output mode of `func`
  * @param timeoutConf used to timeout groups that have not received data in a while
@@ -154,7 +154,7 @@ case class FlatMapGroupsWithStateExec(
       ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
     private val getValueObj =
       ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
-    private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+    private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)

     // Metrics
     private val numUpdatedStateRows = longMetric("numUpdatedStateRows")