Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,58 @@ class Analyzer(
throw new AnalysisException("Only one generator allowed per select clause but found " +
generators.size + ": " + generators.map(toPrettySQL).mkString(", "))

case Aggregate(_, aggregateExprs, _) if aggregateExprs.exists(hasNestedGenerator) =>
  // Reject an aggregate list in which a generator is nested inside another expression.
  // The guard above guarantees a match, so the lookup below cannot come back empty.
  val offendingSql = aggregateExprs
    .find(hasNestedGenerator)
    .map(expr => toPrettySQL(trimAlias(expr)))
    .get
  throw new AnalysisException(
    s"Generators are not supported when it's nested in expressions, but got: $offendingSql")

case Aggregate(_, aggregateExprs, _) if aggregateExprs.count(hasGenerator) > 1 =>
  // Reject an aggregate list holding more than one generator; the error lists every
  // generator found (top-level alias trimmed) so the user can see which ones collide.
  val foundGenerators = aggregateExprs.collect {
    case expr if hasGenerator(expr) => trimAlias(expr)
  }
  val renderedGenerators = foundGenerators.map(toPrettySQL).mkString(", ")
  throw new AnalysisException(
    "Only one generator allowed per aggregate clause but found " +
      s"${foundGenerators.size}: $renderedGenerators")

case Aggregate(groupList, aggList, child) if aggList.forall {
    case AliasedGenerator(_, _, _) => true
    case expr => expr.resolved
  } && aggList.exists(hasGenerator) =>
  // Splits an aggregate whose output list contains a generator into two operators:
  //   Project(<generator applied to attributes> + pass-through attributes,
  //     Aggregate(<aliased generator inputs> + the other aggregate expressions))
  // Every non-generator entry of the aggregate list must already be resolved (see guard).

  // Flipped to true once the generator of the aggregate list has been rewritten.
  var generatorVisited = false

  // Output list of the Project placed on top; slot i is filled for aggList(i).
  val projectExprs = new Array[NamedExpression](aggList.length)
  val trimmedAggList =
    aggList.map(CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
  val newAggList = trimmedAggList.zipWithIndex.flatMap {
    case (AliasedGenerator(generator, names, outer), slot) =>
      // Sanity check only: an earlier case already throws when the aggregate list
      // contains more than one generator, so this should never fire.
      assert(!generatorVisited, "More than one generator found in aggregate.")
      generatorVisited = true

      // Foldable generator inputs are left in place; every other input is given an
      // alias so that it can be computed by the Aggregate below.
      val newGenChildren: Seq[Expression] = generator.children.zipWithIndex.map {
        case (input, _) if input.foldable => input
        case (input, position) => Alias(input, s"_gen_input_${position}")()
      }
      // The generator kept in the Project reads the aliased inputs through attributes.
      val attributeInputs = newGenChildren.map {
        case constant if constant.foldable => constant
        case aliased => aliased.asInstanceOf[Alias].toAttribute
      }
      val rewiredGenerator =
        generator.withNewChildren(attributeInputs).asInstanceOf[Generator]
      val newGenerator = if (outer) GeneratorOuter(rewiredGenerator) else rewiredGenerator
      val newAliasedGenerator = names match {
        case Seq(singleName) => Alias(newGenerator, singleName)()
        case _ => MultiAlias(newGenerator, names)
      }
      projectExprs(slot) = newAliasedGenerator
      // Only the aliased (non-foldable) inputs become outputs of the new Aggregate.
      newGenChildren.filter(input => !input.foldable).asInstanceOf[Seq[NamedExpression]]
    case (passThrough, slot) =>
      // Non-generator expressions stay in the Aggregate; the Project just forwards them.
      projectExprs(slot) = passThrough.toAttribute
      Seq(passThrough)
  }

  val aggregateBelow = Aggregate(groupList, newAggList, child)
  Project(projectExprs.toList, aggregateBelow)

case p @ Project(projectList, child) =>
// Holds the resolved generator, if one exists in the project list.
var resolvedGenerator: Generate = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,40 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
Row(1, null) :: Row(2, null) :: Nil)
}

test("generator in aggregate expression") {
  // Runs `query` and returns the message of the AnalysisException it is expected to raise.
  def analysisErrorOf(query: String): String = {
    intercept[AnalysisException] {
      sql(query)
    }.getMessage
  }

  withTempView("t1") {
    Seq((1, 1), (1, 2), (2, 3)).toDF("c1", "c2").createTempView("t1")

    // Generator over aggregates without grouping: a single group, min(c2) = 1, max(c2) = 3.
    checkAnswer(
      sql("select explode(array(min(c2), max(c2))) from t1"),
      Seq(Row(1), Row(3)))

    // Generator over aggregates with grouping: two output rows per value of c1.
    checkAnswer(
      sql("select posexplode(array(min(c2), max(c2))) from t1 group by c1"),
      Seq(Row(0, 1), Row(1, 2), Row(0, 3), Row(1, 3)))

    // test generator "stack" which require foldable argument
    checkAnswer(
      sql("select stack(2, min(c1), max(c1), min(c2), max(c2)) from t1"),
      Seq(Row(1, 2), Row(1, 3)))

    // A generator nested inside another expression must be rejected.
    val nestedGeneratorError =
      analysisErrorOf("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
    assert(
      nestedGeneratorError.contains(
        "Generators are not supported when it's nested in expressions"))

    // More than one generator in the same aggregate clause must be rejected.
    val multipleGeneratorsError = analysisErrorOf(
      """select
        | explode(array(min(c2), max(c2))),
        | posexplode(array(min(c2), max(c2)))
        |from t1 group by c1
        |""".stripMargin)
    assert(multipleGeneratorsError.contains("Only one generator allowed per aggregate clause"))
  }
}
}

case class EmptyGenerator() extends Generator {
Expand Down