@@ -373,8 +373,8 @@ package object dsl {
     def repartition(num: Integer): LogicalPlan =
       Repartition(num, shuffle = true, logicalPlan)
 
-    def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan =
-      RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n))
+    def distribute(exprs: Expression*)(n: Int): LogicalPlan =
+      RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
     def analyze: LogicalPlan =
      EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
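Note: with the default argument gone, callers of the test DSL must always pass an explicit partition count. A minimal sketch of the new call shape (the imports and the LocalRelation leaf are the usual catalyst test scaffolding, not part of this diff):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val testRelation = LocalRelation('a.int, 'b.int)
    // Previously distribute('a, 'b)() left numPartitions as None;
    // now the count is always supplied in the second parameter list.
    val plan = testRelation.distribute('a, 'b)(2)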
@@ -578,7 +578,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
       RepartitionByExpression(exprs, child, numPartitions)
     // Case 3
     case Repartition(numPartitions, _, r: RepartitionByExpression) =>
-      r.copy(numPartitions = Some(numPartitions))
+      r.copy(numPartitions = numPartitions)
     // Case 3
     case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
       RepartitionByExpression(exprs, child, numPartitions)
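Note: the Case 3 rule is unchanged in intent — an outer repartition(n) still overrides the inner operator's count — only the copy call loses its Some(...) wrapper. A rough sketch of the rewrite under the new signature (the plan construction below is illustrative, not from this diff):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical._

    val child = LocalRelation('a.int)
    val before = Repartition(10, shuffle = true,
      RepartitionByExpression(Seq('a.attr), child, numPartitions = 5))
    // CollapseRepartition collapses this into a single shuffle:
    // RepartitionByExpression(Seq('a.attr), child, numPartitions = 10)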
@@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       Sort(sort.asScala.map(visitSortItem), global = false, query)
     } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // DISTRIBUTE BY ...
-      RepartitionByExpression(expressionList(distributeBy), query)
+      withRepartitionByExpression(ctx, expressionList(distributeBy), query)
     } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // SORT BY ... DISTRIBUTE BY ...
       Sort(
         sort.asScala.map(visitSortItem),
         global = false,
-        RepartitionByExpression(expressionList(distributeBy), query))
+        withRepartitionByExpression(ctx, expressionList(distributeBy), query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
       // CLUSTER BY ...
       val expressions = expressionList(clusterBy)
       Sort(
         expressions.map(SortOrder(_, Ascending)),
         global = false,
-        RepartitionByExpression(expressions, query))
+        withRepartitionByExpression(ctx, expressions, query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
       // [EMPTY]
       query
@@ -273,6 +273,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     }
   }
 
+  /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    throw new ParseException("DISTRIBUTE BY is not supported", ctx)
+  }
+
   /**
    * Create a logical plan using a query specification.
    */
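Note: the base builder throws here because the catalyst module has no SQLConf from which to read a default partition count; the session-side SparkSqlAstBuilder overrides this hook (see further below). Assuming the catalyst-only parser object, the new behavior would look like:

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    CatalystSqlParser.parsePlan("SELECT * FROM t DISTRIBUTE BY a")
    // throws org.apache.spark.sql.catalyst.parser.ParseException:
    //   DISTRIBUTE BY is not supported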
@@ -844,18 +844,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
  * information about the number of partitions during execution. Used when a specific ordering or
  * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
  * `coalesce` and `repartition`.
- * If `numPartitions` is not specified, the number of partitions will be the number set by
- * `spark.sql.shuffle.partitions`.
  */
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Option[Int] = None) extends UnaryNode {
+    numPartitions: Int) extends UnaryNode {
 
-  numPartitions match {
-    case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.")
-    case None => // Ok
-  }
+  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
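Note: with the Option removed, an invalid count now fails fast at construction time rather than surfacing later during planning. A minimal sketch (LocalRelation is just a convenient leaf node):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical._

    val rel = LocalRelation('a.int)
    RepartitionByExpression(Seq('a.attr), rel, numPartitions = 0)
    // throws java.lang.IllegalArgumentException:
    //   requirement failed: Number of partitions (0) must be positive.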
@@ -21,11 +21,12 @@ import java.util.TimeZone
 
 import org.scalatest.ShouldMatchers
 
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
   }
 
   test("pull out nondeterministic expressions from RepartitionByExpression") {
-    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation, numPartitions = 10)
     val projected = Alias(Rand(33), "_nondeterministic")()
     val expected =
       Project(testRelation.output,
         RepartitionByExpression(Seq(projected.toAttribute),
-          Project(testRelation.output :+ projected, testRelation)))
+          Project(testRelation.output :+ projected, testRelation),
+          numPartitions = 10))
     checkAnalysis(plan, expected)
   }
 
@@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest {
     val orderSortDistrClusterClauses = Seq(
       ("", basePlan),
       (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
-      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
-      (" distribute by a, b", basePlan.distribute('a, 'b)()),
-      (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)),
-      (" cluster by a, b", basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc))
+      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc))
     )
 
     orderSortDistrClusterClauses.foreach {
Review comment from the PR author: These three test cases have been moved to SparkSqlParserSuite.scala.

5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2410,7 +2410,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
+    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
   }
 
   /**
@@ -2425,7 +2425,8 @@
    */
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
+    RepartitionByExpression(
+      partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
   }
 
   /**
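Note: user-visible behavior is unchanged; what moves is when the partition count is resolved. A hypothetical session (the names and data are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    val df = spark.range(100).withColumn("key", $"id" % 10)
    df.repartition(10, $"key")  // explicit count, as before
    df.repartition($"key")      // count read from spark.sql.shuffle.partitions
                                // while building the plan, no longer a None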
@@ -22,16 +22,17 @@ import scala.collection.JavaConverters._
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.tree.TerminalNode
 
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser._
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTable, _}
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Concrete parser for Spark SQL statements.
@@ -1441,4 +1442,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
       reader, writer,
       schemaLess)
   }
+
+  /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  override protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+  }
 }
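Note: this override closes the loop — the session parser reads its SQLConf while building the plan, so the node carries a concrete count immediately. A rough end-to-end sketch, reusing the spark session from the earlier sketch and assuming a registered table t:

    spark.conf.set("spark.sql.shuffle.partitions", "5")
    val plan = spark.sql("SELECT * FROM t DISTRIBUTE BY a").queryExecution.logical
    // plan is a RepartitionByExpression over the scan of t with
    // numPartitions = 5, rather than a node carrying numPartitions = None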
@@ -332,8 +332,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
-    def numPartitions: Int = self.numPartitions
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
 
@@ -414,9 +412,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
-      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+      case logical.RepartitionByExpression(expressions, child, numPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
-          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
+          expressions, numPartitions), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
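Note: since the logical node now always carries a concrete count, the strategy no longer needs the fallback from the enclosing planner. The planned shuffle is fully determined by the operator, e.g. (same assumed session as above):

    val df2 = spark.range(100).toDF("a")
    df2.repartition(8, $"a").rdd.getNumPartitions
    // 8 — planned as ShuffleExchange(HashPartitioning(Seq(a), 8), ...)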
@@ -19,10 +19,12 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.CreateTable
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
@@ -36,7 +38,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
  */
 class SparkSqlParserSuite extends PlanTest {
 
-  private lazy val parser = new SparkSqlParser(new SQLConf)
+  val newConf = new SQLConf
+  private lazy val parser = new SparkSqlParser(newConf)
 
   /**
    * Normalizes plans:
@@ -251,4 +254,29 @@ class SparkSqlParserSuite extends PlanTest {
     assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
       AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
   }
+
+  test("query organization") {
+    // Covers the distribute by / cluster by clauses that moved here from PlanParserSuite.
+    val baseSql = "select * from t"
+    val basePlan =
+      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("t")))
+
+    assertEqual(s"$baseSql distribute by a, b",
+      RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+        basePlan,
+        numPartitions = newConf.numShufflePartitions))
+    assertEqual(s"$baseSql distribute by a sort by b",
+      Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+    assertEqual(s"$baseSql cluster by a, b",
+      Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
+          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+  }
 }