@@ -20,13 +20,14 @@ package execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
   override def output = projectList.map(_.toAttribute)
@@ -70,17 +71,24 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
   override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
   override def execute() = {
-    child.execute()
-      .mapPartitions(_.take(limit).map(_.copy()))
-      .coalesce(1, shuffle = true)
-      .mapPartitions(_.take(limit))
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }
 
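A note separate from the patch itself: the new execute() replaces coalesce(1, shuffle = true) with an explicit ShuffledRDD keyed by a dummy Boolean (through a reused MutablePair) and serialized with SparkSqlSerializer, so that at most `limit` rows per partition are shuffled to a single partition, where the limit is applied once more. The sketch below illustrates the same "local take, shuffle to one partition, take again" pattern using only the public RDD API; it uses plain (Boolean, row) tuples instead of MutablePair, the default serializer, and a hypothetical helper name limitViaShuffle, so it is an illustration of the pattern, not the code in this diff.

// Illustration only (not part of the patch): the limit-via-shuffle pattern
// with the public RDD API. `limitViaShuffle` is a hypothetical name.
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object LimitSketch {
  // Take at most `limit` elements per partition, shuffle the survivors to a
  // single partition, then apply the limit once more for the global result.
  def limitViaShuffle[T: scala.reflect.ClassTag](rdd: RDD[T], limit: Int): RDD[T] = {
    rdd
      .mapPartitions(_.take(limit))            // partition-local limit
      .map(row => (false, row))                // dummy key, mirroring mutablePair.update(false, row)
      .partitionBy(new HashPartitioner(1))     // one reducer sees every surviving element
      .mapPartitions(_.take(limit).map(_._2))  // global limit, drop the dummy key
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("limit-sketch"))
    println(limitViaShuffle(sc.parallelize(1 to 100, 4), 5).collect().toSeq)
    sc.stop()
  }
}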