diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 58f98d529ab5..8847595ed987 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1918,14 +1918,9 @@ class Analyzer(
       case p: Project => p
       case f: Filter => f
 
-      // todo: It's hard to write a general rule to pull out nondeterministic expressions
-      // from LogicalPlan, currently we only do it for UnaryNode which has same output
-      // schema with its child.
-      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
-        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
-          val leafNondeterministic = expr.collect {
-            case n: Nondeterministic => n
-          }
+      case p: UnaryNode if p.expressions.exists(!_.deterministic) =>
+        val nondeterExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
+          val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
           leafNondeterministic.map { e =>
             val ne = e match {
               case n: NamedExpression => n
@@ -1934,11 +1929,21 @@
             new TreeNodeRef(e) -> ne
           }
         }.toMap
-        val newPlan = p.transformExpressions { case e =>
-          nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+
+        // Evaluate each nondeterministic leaf exactly once, in a Project below p.
+        val newChild = Project(p.child.output ++ nondeterExprs.values, p.child)
+
+        val newPlan = p.transformExpressions {
+          case e if nondeterExprs.contains(new TreeNodeRef(e)) =>
+            nondeterExprs(new TreeNodeRef(e)).toAttribute
+        }.withNewChildren(newChild :: Nil)
+
+        // Add a trimming Project only if the helper columns leak into p's output.
+        if (newPlan.output != p.output) {
+          Project(p.output, newPlan)
+        } else {
+          newPlan
         }
-        val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
-        Project(p.output, newPlan.withNewChildren(newChild :: Nil))
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
new file mode 100644
index 000000000000..e46d8e8bef54
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * Test suite for moving non-deterministic expressions into Project.
+ */
+class PullOutNondeterministicSuite extends AnalysisTest {
+
+  private lazy val a = 'a.int
+  private lazy val b = 'b.int
+  private lazy val r = LocalRelation(a, b)
+  private lazy val rnd = Rand(10).as('_nondeterministic)
+  private lazy val rndref = rnd.toAttribute
+
+  test("no-op on filter") {
+    checkAnalysis(
+      r.where(Rand(10) > Literal(1.0)),
+      r.where(Rand(10) > Literal(1.0))
+    )
+  }
+
+  test("sort") {
+    checkAnalysis(
+      r.sortBy(SortOrder(Rand(10), Ascending)),
+      r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
+    )
+  }
+
+  test("aggregate") {
+    checkAnalysis(
+      r.groupBy(rnd)(rnd),
+      r.select(a, b, rnd).groupBy(rndref)(rndref)
+    )
+  }
+}
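
For reference, a minimal sketch (not part of the patch) of the rewrite the updated rule performs, expressed with the same catalyst DSL the new suite uses. It mirrors the "sort" test case above; `rel`, `before`, and `after` are names local to this sketch.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val a = 'a.int
    val b = 'b.int
    val rel = LocalRelation(a, b)

    // Before analysis: the sort keys on a raw nondeterministic expression,
    // which the operator could re-evaluate with differing results.
    val before = rel.sortBy(SortOrder(Rand(10), Ascending))

    // After the rule: rand(10) is evaluated once in a Project placed below
    // the Sort, the Sort keys on the resulting attribute, and a trailing
    // Project drops the helper column to preserve the original output.
    val rnd = Rand(10).as('_nondeterministic)
    val after = rel
      .select(a, b, rnd)
      .sortBy(SortOrder(rnd.toAttribute, Ascending))
      .select(a, b)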