@@ -150,13 +150,20 @@ class SimpleTestOptimizer extends Optimizer(

/**
* Pushes projects down beneath Sample to enable column pruning with sampling.
* This rule is only doable when the projects don't add new attributes.
*/
object PushProjectThroughSample extends Rule[LogicalPlan] {
Contributor:
should we merge this rule into ColumnPruning?

Contributor:
ah, looks like ColumnPruning already handles it, can we just remove this rule?

Member Author:
Not sure which part you mean? I don't see ColumnPruning handling Sample?

Contributor:
the last case in ColumnPruning, it will generate a new Project under Sample

Member Author:
yah. I will update this. At least one optimizer test uses this rule. The test should be changed too.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down projection into sample
- case Project(projectList, Sample(lb, up, replace, seed, child)) =>
+ case p @ Project(projectList, Sample(lb, up, replace, seed, child))
+ if !hasNewOutput(projectList, p.child.output) =>
Sample(lb, up, replace, seed, Project(projectList, child))()
}
private def hasNewOutput(
projectList: Seq[NamedExpression],
childOutput: Seq[Attribute]): Boolean = {
projectList.exists(p => !childOutput.exists(_.semanticEquals(p)))
Contributor:
It's hard to understand what the code does -- two exists and negation. Can you "untangle" it?

}
}
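
One possible way to untangle the `hasNewOutput` predicate discussed in the review is to name the inner check and replace "exists + negated exists" with a negated `forall`. The sketch below is an illustration only, not the code that was merged: `Attr` is a hypothetical stand-in for Catalyst's `NamedExpression`/`Attribute`, its `semanticEquals` is simplified to an `exprId` comparison, and `isProducedByChild` is a helper name introduced here.

```scala
// Hypothetical stand-in for Catalyst attributes; not the real Spark classes.
final case class Attr(name: String, exprId: Long) {
  // Simplifying assumption: attributes are semantically equal when exprIds match.
  def semanticEquals(other: Attr): Boolean = exprId == other.exprId
}

object HasNewOutputSketch {
  // Shape used in the diff: an exists wrapping a negated exists.
  def hasNewOutputOriginal(projectList: Seq[Attr], childOutput: Seq[Attr]): Boolean =
    projectList.exists(p => !childOutput.exists(_.semanticEquals(p)))

  // Named inner check (hypothetical helper): is this expression already produced by the child?
  def isProducedByChild(expr: Attr, childOutput: Seq[Attr]): Boolean =
    childOutput.exists(_.semanticEquals(expr))

  // Equivalent form: the project adds something new unless every
  // projected expression is already produced by the child.
  def hasNewOutput(projectList: Seq[Attr], childOutput: Seq[Attr]): Boolean =
    !projectList.forall(isProducedByChild(_, childOutput))

  def main(args: Array[String]): Unit = {
    val a = Attr("a", 1L)
    val b = Attr("b", 2L)
    val aa = Attr("aa", 3L) // models `'a as 'aa`, which gets a fresh id

    println(hasNewOutput(Seq(a), Seq(a, b)))  // false: pure column pruning
    println(hasNewOutput(Seq(aa), Seq(a, b))) // true: the alias is a new attribute
  }
}
```

Both forms return the same result for every input, including an empty project list, so the change would be purely about readability.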

/**
@@ -601,6 +601,22 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer.analyze)
}

test("don't push project down into sample if project brings new attributes") {
val x = testRelation.subquery('x)
val originalQuery =
Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa)

val originalQueryAnalyzed =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery))

val optimized = Optimize.execute(originalQueryAnalyzed)

val correctAnswer =
Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa)

comparePlans(optimized, correctAnswer.analyze)
}

test("aggregate: push down filter when filter on group by expression") {
val originalQuery = testRelation
.groupBy('a)('a, count('b) as 'c)
29 changes: 29 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -422,6 +422,35 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
3, 17, 27, 58, 62)
}

test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") {
val udfOne = spark.udf.register("udfOne", (n: Int) => {
Contributor:
Is there a reason why you use spark.udf.register rather than udf directly?

val udfOne = udf { n: Int => ... }

if (n == 1) {
throw new RuntimeException("udfOne shouldn't see swid=1!")
Member:
Use require? Generally, RuntimeException isn't used directly. Really minor.

Member Author:
Thanks! I've updated it.

} else {
1
}
})
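
For reference, the `require` suggestion from the review could look roughly like the sketch below. It is plain Scala with no Spark dependency; `udfOneBody` is a hypothetical name for the function body that would be passed to the UDF, and the updated code in the PR may differ.

```scala
object RequireSketch {
  // Hypothetical plain-Scala version of the UDF body.
  // require throws IllegalArgumentException when the condition is false.
  def udfOneBody(n: Int): Int = {
    require(n != 1, "udfOne shouldn't see swid=1!")
    1
  }

  def main(args: Array[String]): Unit = {
    println(udfOneBody(0)) // 1
    try {
      udfOneBody(1)
    } catch {
      case e: IllegalArgumentException =>
        // Scala prefixes the message with "requirement failed: ".
        println(e.getMessage)
    }
  }
}
```

Compared with throwing `RuntimeException` by hand, `require` removes the `if`/`else` and states the precondition in one line.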

val d = Seq(
(0, "string0"),
(1, "string1"),
(2, "string2"),
(3, "string3"),
(4, "string4"),
(5, "string5"),
(6, "string6"),
(7, "string7"),
(8, "string8"),
(9, "string9")
)
val df = spark.createDataFrame(d).toDF("swid", "stringData")
Contributor:
d.toDF(...) should work too, shouldn't it?

val sampleDF = df.sample(false, 0.7, 50)
// After sampling, sampleDF doesn't contain swid=1.
assert(!sampleDF.select("swid").collect.contains(1))
// udfOne should not encounter swid=1.
sampleDF.select(udfOne($"swid")).collect
Contributor:
I assume you're calling collect to trigger assert, aren't you? If so, why don't you return true/false to denote it and do assert here instead?

Member Author:
Sample will filter out the returned value of swid=1. So I simply call collect to verify if the exception will be thrown or not.

}
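
The concern behind this test can be illustrated without Spark. In the sketch below, `sample` is a hypothetical seeded Bernoulli filter over an in-memory `Seq`, not Spark's sampler, and the "UDF" is a closure that records its inputs. It only shows that evaluating the projection below the sample exposes the function to rows the sample would have dropped.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object SampleOrderSketch {
  // Hypothetical stand-in for a seeded sample: keeps each row with probability `fraction`.
  def sample[A](rows: Seq[A], fraction: Double, seed: Long): Seq[A] = {
    val rng = new Random(seed)
    rows.filter(_ => rng.nextDouble() < fraction)
  }

  // Runs both plan shapes and returns the inputs the UDF stand-in saw in each:
  // (Project over Sample, Sample over Project).
  def observedInputs(rows: Seq[Int], fraction: Double, seed: Long): (Seq[Int], Seq[Int]) = {
    val seenSampleFirst = ArrayBuffer.empty[Int]
    val seenProjectFirst = ArrayBuffer.empty[Int]

    // Project(Sample(child)): the function runs only on sampled rows.
    val sampledThenProjected =
      sample(rows, fraction, seed).map { n => seenSampleFirst += n; 1 }

    // Sample(Project(child)): the function runs on every row before sampling.
    val projectedThenSampled =
      sample(rows.map { n => seenProjectFirst += n; 1 }, fraction, seed)

    (seenSampleFirst.toSeq, seenProjectFirst.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val rows = (0 to 9).toList
    val (sampleFirst, projectFirst) = observedInputs(rows, 0.7, 50L)
    println(s"Project over Sample saw: $sampleFirst")
    println(s"Sample over Project saw: $projectFirst")
  }
}
```

In the second shape the function sees all ten inputs, which is why a UDF that throws on a filtered-out value works as a probe for whether the projection was pushed below the sample.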

test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")