@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.StructType

/**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {

protected val codegenEnabled: Boolean = conf.codegenEnabled

protected val unsafeEnabled: Boolean = conf.unsafeEnabled

lazy val schema: StructType = StructType.fromAttributes(output)

private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")

def output: Seq[Attribute]

/**
* Initializes the iterator state. Must be called before calling `next()`.
*
@@ -17,13 +17,12 @@

package org.apache.spark.sql.execution.local

import java.util.Random

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}


/**
* Sample the dataset.
*
@@ -51,18 +50,15 @@ case class SampleNode(

override def open(): Unit = {
child.open()
val (sampler, _seed) = if (withReplacement) {
val random = new Random(seed)
val sampler =
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
// of DataFrame
random.nextLong())
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
} else {
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
}
sampler.setSeed(_seed)
sampler.setSeed(seed)
PR author (Contributor):
@zsxwing I had to remove this to make testing deterministic. Looking at this further I still don't see the point of introducing another layer of randomness here. What change in behavior does this entail?

Project member:
I was using DataFrame.sample to test SampleNode and it mocked the behavior of DataFrame.sample(withReplacement = true). Since you don't use DataFrame to test it now, I agree that we can remove this tricky logic.

iterator = sampler.sample(child.asIterator)
}
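
On the determinism point raised in the thread above: once the extra `Random` layer is gone, the sampled output depends only on `seed`, so repeated runs over the same input agree. A minimal sketch of that argument, reusing the `PoissonSampler` API already imported in this file (the `0.5` fraction and the integer input below are made-up illustration values, not part of the patch):

```scala
import org.apache.spark.util.random.PoissonSampler

// Hypothetical helper: sample a fixed input with a fixed seed.
def sampleOnce(seed: Long, input: Seq[Int]): Seq[Int] = {
  val sampler = new PoissonSampler[Int](0.5, useGapSamplingIfPossible = false)
  sampler.setSeed(seed)                 // seed the sampler directly, as the new code does
  sampler.sample(input.iterator).toSeq
}

val input = (1 to 100).toSeq
// With the seed applied directly, two runs produce the same sample.
assert(sampleOnce(42L, input) == sampleOnce(42L, input))
```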

Expand Down
@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
}
// Close it eagerly since we don't need it.
child.close()
iterator = queue.iterator
iterator = queue.toArray.sorted(ord).iterator
}
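
The fix above sorts the drained queue because a heap-backed priority queue does not iterate in priority order; only repeated dequeues or an explicit sort yield sorted output. A small standalone sketch of that pitfall using Scala's mutable `PriorityQueue` (the queue in `TakeOrderedAndProjectNode` is Spark's bounded variant, which carries the same heap-order caveat for its iterator):

```scala
import scala.collection.mutable

// Smallest element dequeued first, mirroring an ascending top-K ordering.
val queue = mutable.PriorityQueue(5, 1, 4, 2, 3)(Ordering.Int.reverse)

val viaIterator = queue.iterator.toSeq        // heap order: NOT guaranteed to be sorted
val viaSort     = queue.toArray.sorted.toSeq  // explicit sort, as in the fix above

assert(viaSort == Seq(1, 2, 3, 4, 5))
```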

override def next(): Boolean = {
@@ -238,7 +238,7 @@ object SparkPlanTest {
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan.transformExpressions {
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

/**
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
*/
private[local] case class DummyNode(
output: Seq[Attribute],
relation: LocalRelation,
conf: SQLConf)
extends LocalNode(conf) {

import DummyNode._

private var index: Int = CLOSED
private val input: Seq[InternalRow] = relation.data

def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
this(output, LocalRelation.fromProduct(output, data), conf)
}

def isOpen: Boolean = index != CLOSED

override def children: Seq[LocalNode] = Seq.empty

override def open(): Unit = {
index = -1
}

override def next(): Boolean = {
index += 1
index < input.size
}

override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
input(index)
}

override def close(): Unit = {
index = CLOSED
}
}

private object DummyNode {
val CLOSED: Int = Int.MinValue
}
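
DummyNode is the simplest possible implementation of the `open()`/`next()`/`fetch()`/`close()` contract described in LocalNode. A sketch of how a test might drive it by hand, assuming `kvIntAttributes` is the two-int-column attribute list the suites below already use:

```scala
val node = new DummyNode(kvIntAttributes, Seq((1, 10), (2, 20), (3, 30)))
val results = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]

node.open()                         // must precede next()/fetch()
while (node.next()) {
  val row = node.fetch()            // valid only while next() has returned true
  results += ((row.getInt(0), row.getInt(1)))
}
node.close()

assert(results == Seq((1, 10), (2, 20), (3, 30)))
assert(!node.isOpen)
```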
@@ -17,35 +17,33 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.dsl.expressions._


class ExpandNodeSuite extends LocalNodeTest {

import testImplicits._

test("expand") {
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
checkAnswer(
input,
node =>
ExpandNode(conf, Seq(
Seq(
input.col("key") + input.col("value"), input.col("key") - input.col("value")
).map(_.expr),
Seq(
input.col("key") * input.col("value"), input.col("key") / input.col("value")
).map(_.expr)
), node.output, node),
Seq(
(2, 0),
(1, 1),
(4, 0),
(4, 1),
(6, 0),
(9, 1),
(8, 0),
(16, 1),
(10, 0),
(25, 1)
).toDF().collect()
)
private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val inputNode = new DummyNode(kvIntAttributes, inputData)
val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
val resolvedNode = resolveExpressions(expandNode)
val expectedOutput = {
val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
firstHalf ++ secondHalf
}
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput.toSet === expectedOutput.toSet)
}

test("empty") {
testExpand()
}

test("basic") {
testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
}

}
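
For concreteness, here is what the two projections in `testExpand` yield for one hypothetical input row; the suite compares sets because the order in which ExpandNode emits the projected rows is not what the test asserts:

```scala
// One made-up input row (k, v) = (3, 3000):
val (k, v) = (3, 3000)
val fromFirstProjection  = (k + v, k - v)   // (3003, -2997)
val fromSecondProjection = (k * v, k / v)   // (9000, 0) -- integer division
```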
@@ -17,25 +17,29 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._

class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {

test("basic") {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
node => FilterNode(conf, condition.expr, node),
testData.filter(condition).collect()
)
class FilterNodeSuite extends LocalNodeTest {

private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val cond = 'k % 2 === 0
val inputNode = new DummyNode(kvIntAttributes, inputData)
val filterNode = new FilterNode(conf, cond, inputNode)
val resolvedNode = resolveExpressions(filterNode)
val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}

test("empty") {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
node => FilterNode(conf, condition.expr, node),
emptyTestData.filter(condition).collect()
)
testFilter()
}

test("basic") {
testFilter((1 to 100).map { i => (i, i) }.toArray)
}

}