From 139aefa4604c2cbaacae884499f321be0e1324c8 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Tue, 22 May 2018 08:31:26 +0900
Subject: [PATCH 1/3] [SPARK-24336][SQL] Support 'pass through' transformation in BasicOperators

---
 .../spark/sql/catalyst/ScalaReflection.scala | 27 +++++
 .../spark/sql/execution/SparkStrategies.scala | 106 +++++++++++++-----
 2 files changed, 107 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f9acc208b715..adcfc1b1ec64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -703,6 +703,33 @@ object ScalaReflection extends ScalaReflection {
    */
   def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
 
+  def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
+    tpe.dealias match {
+      case ty if ty <:< localTypeOf[Array[_]] =>
+        def arrayClassFromType(tpe: `Type`): Class[_] =
+          ScalaReflection.cleanUpReflectionObjects {
+            tpe.dealias match {
+              case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+              case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+              case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+              case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+              case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+              case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+              case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+              case _ =>
+                // There is probably a better way to do this, but I couldn't find it...
+                val elementType = getClassFromTypeHandleArray(tpe)
+                java.lang.reflect.Array.newInstance(elementType, 1).getClass
+            }
+          }
+
+        val TypeRef(_, _, Seq(elementType)) = ty
+        arrayClassFromType(elementType)
+
+      case ty => getClassFromType(ty)
+    }
+  }
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
   /** Returns a Sequence of attributes for the given case class type. */
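The new helper is needed because getClassFromType resolves a type only through its type symbol, which presumably cannot recover a concrete array class such as Array[Int] (class [I). Below is a minimal, self-contained sketch of the same construction, using names of our own choosing rather than Spark's:

import scala.reflect.runtime.universe._

object ArrayClassSketch {
  // Illustrative stand-in for the new helper (not Spark code): resolve the runtime
  // Class of Array[T] from the reflected element Type.
  def classForArrayOf(elementType: Type): Class[_] = elementType.dealias match {
    case t if t <:< definitions.IntTpe => classOf[Array[Int]]         // class [I
    case t if t <:< definitions.LongTpe => classOf[Array[Long]]       // class [J
    case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] // class [Z
    case t =>
      // Reference element types: obtain the element class, then build the array class from it.
      val mirror = runtimeMirror(getClass.getClassLoader)
      val elementClass = mirror.runtimeClass(t.typeSymbol.asClass)
      java.lang.reflect.Array.newInstance(elementClass, 0).getClass
  }

  def main(args: Array[String]): Unit = {
    // Extract the element type the same way the patch does, via the TypeRef extractor.
    val TypeRef(_, _, Seq(intElem)) = typeOf[Array[Int]]
    val TypeRef(_, _, Seq(strElem)) = typeOf[Array[String]]
    println(classForArrayOf(intElem)) // class [I
    println(classForArrayOf(strElem)) // class [Ljava.lang.String;
  }
}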
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 82b4eb9fba24..3873da5652ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.execution
 
+import java.lang.reflect.Constructor
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -474,8 +477,80 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
+
+    import universe._
+
+    // Enumerate the pairs of logical and physical plan classes that can be converted via
+    // 'pass-through', which is possible when the only difference between the primary
+    // constructor parameters of the two plans is LogicalPlan vs SparkPlan.
+    // The map should exclude pairs for which 'pass-through' would need to rely on default
+    // values of constructor parameters.
+    val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
+      (classOf[logical.DeserializeToObject], classOf[execution.DeserializeToObjectExec]),
+      (classOf[logical.SerializeFromObject], classOf[execution.SerializeFromObjectExec]),
+      (classOf[logical.MapPartitions], classOf[execution.MapPartitionsExec]),
+      (classOf[logical.FlatMapGroupsInR], classOf[execution.FlatMapGroupsInRExec]),
+      (classOf[logical.FlatMapGroupsInPandas], classOf[execution.python.FlatMapGroupsInPandasExec]),
+      (classOf[logical.AppendColumnsWithObject], classOf[execution.AppendColumnsWithObjectExec]),
+      (classOf[logical.MapGroups], classOf[execution.MapGroupsExec]),
+      (classOf[logical.CoGroup], classOf[execution.CoGroupExec]),
+      (classOf[logical.Project], classOf[execution.ProjectExec]),
+      (classOf[logical.Filter], classOf[execution.FilterExec]),
+      (classOf[logical.Window], classOf[execution.window.WindowExec]),
+      (classOf[logical.Sample], classOf[execution.SampleExec])
+    )
+
+    lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
+      passThroughOperators.map {
+        case (srcOpCls, _) =>
+          (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
+      }.toMap
+
+    lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
+      passThroughOperators.map {
+        case (srcOpCls, tgtOpCls) =>
+          val logicalPlanCls = classOf[LogicalPlan]
+          val m = runtimeMirror(logicalPlanCls.getClassLoader)
+          val classSymbol = m.staticClass(logicalPlanCls.getName)
+          val logicalPlanType = classSymbol.selfType
+
+          val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
+          val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
+            paramTypes.map {
+              case ty if ty <:< logicalPlanType =>
+                m.staticClass(classOf[SparkPlan].getName).selfType
+
+              case ty => ty
+            }
+          }
+
+          val convertedParamClasses = convertedParamTypes.map(
+            ScalaReflection.getClassFromTypeHandleArray)
+          val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)
+
+          constructorOption match {
+            case Some(const: Constructor[_]) => (srcOpCls, const)
+            case _ => throw new IllegalStateException(
+              s"Matching constructor ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
+          }
+      }.toMap
+
+    def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
+      val srcClass = src.getClass
+      require(passThroughOperators.contains(srcClass))
+      val paramValues = operatorToConstructorParameters(srcClass).map(_._1).map { name =>
+        srcClass.getMethod(name).invoke(src)
+      }
+      val convertedParamValues = paramValues.map {
+        case p if p.isInstanceOf[LogicalPlan] => planLater(p.asInstanceOf[LogicalPlan])
+        case p => p
+      }
+
+      val const = operatorToTargetConstructor(srcClass)
+      const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
@@ -497,36 +572,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical except operator should have been replaced by anti-join in the optimizer")
 
-      case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
-      case logical.SerializeFromObject(serializer, child) =>
-        execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
-      case logical.MapPartitions(f, objAttr, child) =>
-        execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+      case src if passThroughOperators.contains(src.getClass) =>
+        createPassThroughOutputPlan(src) :: Nil
+
       case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
         execution.MapPartitionsExec(
           execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
          data, objAttr, planLater(child)) :: Nil
-      case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
-        execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
-      case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
-        execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
           f, key, value, grouping, data, output, _, _, _, timeout, child) =>
         execution.MapGroupsExec(
           f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
-      case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
-        execution.CoGroupExec(
-          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
-          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
@@ -536,18 +598,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       case logical.Sort(sortExprs, global, child) =>
         execution.SortExec(sortExprs, global, planLater(child)) :: Nil
-      case logical.Project(projectList, child) =>
-        execution.ProjectExec(projectList, planLater(child)) :: Nil
-      case logical.Filter(condition, child) =>
-        execution.FilterExec(condition, planLater(child)) :: Nil
       case f: logical.TypedFilter =>
         execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
      case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
-        execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
-      case logical.Sample(lb, ub, withReplacement, seed, child) =>
-        execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data, _) =>
         LocalTableScanExec(output, data) :: Nil
       case logical.LocalLimit(IntegerLiteral(limit), child) =>
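This first commit replaces several hand-written conversions with one table-driven rule. The toy program below illustrates the idea with made-up node classes (LogicalFilter, PhysicalFilter and friends are not Spark types); it also simplifies the mechanics by reading constructor arguments positionally via productIterator, whereas the patch resolves the target constructor by converted parameter types and reads each field through its accessor by name:

object PassThroughSketch {
  trait LogicalNode
  trait PhysicalNode

  case class LogicalFilter(condition: String, child: LogicalNode) extends LogicalNode
  case class PhysicalFilter(condition: String, child: PhysicalNode) extends PhysicalNode
  case class LogicalLeaf(name: String) extends LogicalNode
  case class PhysicalLeaf(name: String) extends PhysicalNode

  // Stand-in for planLater: a real planner would recurse over arbitrary children.
  def plan(node: LogicalNode): PhysicalNode = node match {
    case LogicalLeaf(name) => PhysicalLeaf(name)
    case f: LogicalFilter  => passThrough(f, classOf[PhysicalFilter])
  }

  // Generic pass-through: copy the logical node's constructor arguments, converting
  // every LogicalNode child with plan(), then invoke the target's constructor.
  def passThrough(src: Product with LogicalNode, target: Class[_ <: PhysicalNode]): PhysicalNode = {
    val args = src.productIterator.map {
      case child: LogicalNode => plan(child)
      case other              => other
    }.map(_.asInstanceOf[AnyRef]).toArray
    target.getConstructors.head.newInstance(args: _*).asInstanceOf[PhysicalNode]
  }

  def main(args: Array[String]): Unit = {
    println(plan(LogicalFilter("x > 1", LogicalLeaf("scan"))))
    // PhysicalFilter(x > 1,PhysicalLeaf(scan))
  }
}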
From 971abb6caf79b9e24c063f6e4bd3d0a6aefe133e Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Tue, 22 May 2018 08:51:20 +0900
Subject: [PATCH 2/3] remove case which is already defined in 'pass through'

---
 .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3873da5652ad..4a9e9caad7f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -578,9 +578,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
         execution.MapPartitionsExec(
           execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
-      case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
-        execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
-          data, objAttr, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
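The case removed here had become unreachable rather than merely redundant: FlatMapGroupsInR is listed in passThroughOperators, and the class-keyed guard runs before the later extractor cases. A small, hypothetical illustration of that match-ordering effect:

object MatchOrderSketch {
  sealed trait Node
  case class A(x: Int) extends Node
  case class B(x: Int) extends Node

  val handledByTable: Set[Class[_ <: Node]] = Set(classOf[A])

  def dispatch(n: Node): String = n match {
    case src if handledByTable.contains(src.getClass) => s"pass-through: $src"
    case A(x) => s"explicit A case: $x" // dead for A: the guard above always matches first
    case B(x) => s"explicit B case: $x"
  }

  def main(args: Array[String]): Unit = {
    println(dispatch(A(1))) // pass-through: A(1)
    println(dispatch(B(2))) // explicit B case: 2
  }
}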
From 6e6c375a1c8846e9421d873dd27a880698464c8f Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Tue, 22 May 2018 16:48:55 +0900
Subject: [PATCH 3/3] Respect style guide, move nested function to private method

---
 .../spark/sql/catalyst/ScalaReflection.scala | 34 ++++++------
 .../spark/sql/execution/SparkStrategies.scala | 52 +++++++++----------
 2 files changed, 42 insertions(+), 44 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index adcfc1b1ec64..8f932a80098b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -706,23 +706,6 @@ object ScalaReflection extends ScalaReflection {
   def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
     tpe.dealias match {
       case ty if ty <:< localTypeOf[Array[_]] =>
-        def arrayClassFromType(tpe: `Type`): Class[_] =
-          ScalaReflection.cleanUpReflectionObjects {
-            tpe.dealias match {
-              case t if t <:< definitions.IntTpe => classOf[Array[Int]]
-              case t if t <:< definitions.LongTpe => classOf[Array[Long]]
-              case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
-              case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
-              case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
-              case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
-              case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
-              case _ =>
-                // There is probably a better way to do this, but I couldn't find it...
-                val elementType = getClassFromTypeHandleArray(tpe)
-                java.lang.reflect.Array.newInstance(elementType, 1).getClass
-            }
-          }
-
         val TypeRef(_, _, Seq(elementType)) = ty
         arrayClassFromType(elementType)
 
@@ -730,6 +713,23 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  private def arrayClassFromType(tpe: Type): Class[_] =
+    ScalaReflection.cleanUpReflectionObjects {
+      tpe.dealias match {
+        case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+        case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+        case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+        case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+        case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+        case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+        case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+        case _ =>
+          // There is probably a better way to do this, but I couldn't find it...
+          val elementType = getClassFromTypeHandleArray(tpe)
+          java.lang.reflect.Array.newInstance(elementType, 1).getClass
+      }
+    }
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
   /** Returns a Sequence of attributes for the given case class type. */
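The relocated helper keeps its recursive fallback: when the element type is itself an array, the element's Class is computed first and a throwaway one-element array is allocated just to obtain the corresponding array Class. A standalone sketch of that trick (our names, not Spark's):

object NestedArrayClassSketch {
  // Given an element Class, build the Class of arrays of that element.
  def arrayClassOf(element: Class[_]): Class[_] =
    java.lang.reflect.Array.newInstance(element, 1).getClass

  def main(args: Array[String]): Unit = {
    println(arrayClassOf(classOf[Int]))        // class [I  (Int maps to the primitive int)
    println(arrayClassOf(classOf[Array[Int]])) // class [[I, i.e. Array[Array[Int]]
    println(arrayClassOf(classOf[String]))     // class [Ljava.lang.String;
  }
}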
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4a9e9caad7f2..dc3c5dfe8f26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -502,39 +502,37 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     )
 
     lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
-      passThroughOperators.map {
-        case (srcOpCls, _) =>
-          (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
+      passThroughOperators.map { case (srcOpCls, _) =>
+        (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
       }.toMap
 
     lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
-      passThroughOperators.map {
-        case (srcOpCls, tgtOpCls) =>
-          val logicalPlanCls = classOf[LogicalPlan]
-          val m = runtimeMirror(logicalPlanCls.getClassLoader)
-          val classSymbol = m.staticClass(logicalPlanCls.getName)
-          val logicalPlanType = classSymbol.selfType
-
-          val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
-          val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
-            paramTypes.map {
-              case ty if ty <:< logicalPlanType =>
-                m.staticClass(classOf[SparkPlan].getName).selfType
-
-              case ty => ty
-            }
+      passThroughOperators.map { case (srcOpCls, tgtOpCls) =>
+        val logicalPlanCls = classOf[LogicalPlan]
+        val m = runtimeMirror(logicalPlanCls.getClassLoader)
+        val classSymbol = m.staticClass(logicalPlanCls.getName)
+        val logicalPlanType = classSymbol.selfType
+
+        val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
+        val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
+          paramTypes.map {
+            case ty if ty <:< logicalPlanType =>
+              m.staticClass(classOf[SparkPlan].getName).selfType
+
+            case ty => ty
           }
+        }
 
-          val convertedParamClasses = convertedParamTypes.map(
-            ScalaReflection.getClassFromTypeHandleArray)
-          val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)
+        val convertedParamClasses = convertedParamTypes.map(
+          ScalaReflection.getClassFromTypeHandleArray)
+        val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)
 
-          constructorOption match {
-            case Some(const: Constructor[_]) => (srcOpCls, const)
-            case _ => throw new IllegalStateException(
-              s"Matching constructor ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
-          }
-      }.toMap
+        constructorOption match {
+          case Some(const: Constructor[_]) => (srcOpCls, const)
+          case _ => throw new IllegalStateException(
+            s"Matching constructor ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
+        }
+      }.toMap
 
     def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
       val srcClass = src.getClass