@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.StructType

/**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {

protected val codegenEnabled: Boolean = conf.codegenEnabled

protected val unsafeEnabled: Boolean = conf.unsafeEnabled

lazy val schema: StructType = StructType.fromAttributes(output)

private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")

def output: Seq[Attribute]

/**
* Initializes the iterator state. Must be called before calling `next()`.
*
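As context for the contract described in the class doc above: a caller is expected to bracket iteration with open() and close(), calling fetch() only after a successful next(). A minimal consumption sketch, assuming only the LocalNode API shown in this patch (the consume helper itself is illustrative, not part of the change):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Illustrative helper (not in this patch); assumes LocalNode from
// org.apache.spark.sql.execution.local is in scope.
def consume(node: LocalNode)(f: InternalRow => Unit): Unit = {
  node.open()               // must be called before next()/fetch()
  try {
    while (node.next()) {   // returns false once the input is exhausted
      f(node.fetch())       // fetch() is idempotent until the next next() call
    }
  } finally {
    node.close()            // always release the iterator state
  }
}
```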

@@ -17,12 +17,11 @@

package org.apache.spark.sql.execution.local

import java.util.Random

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, RandomSampler}


/**
* Sample the dataset.
@@ -51,18 +50,15 @@ case class SampleNode(

override def open(): Unit = {
child.open()
val (sampler, _seed) = if (withReplacement) {
val random = new Random(seed)
val sampler =
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
// of DataFrame
random.nextLong())
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
} else {
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
}
sampler.setSeed(_seed)
sampler.setSeed(seed)
Contributor Author:
@zsxwing I had to remove this to make testing deterministic. Looking at this further, I still don't see the point of introducing another layer of randomness here. What change in behavior does this entail?

Member:
I was using DataFrame.sample to test SampleNode, and it mocked the behavior of DataFrame.sample(withReplacement = true). Since you don't use DataFrame to test it now, I agree that we can remove this tricky logic.
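For context on the determinism point in this thread: once the sampler is seeded directly, the same seed over the same input yields the same rows on every run, which is what makes the node testable without DataFrame. A rough standalone sketch (illustration only, not part of the patch) using the same BernoulliCellSampler the node imports:

```scala
import org.apache.spark.util.random.BernoulliCellSampler

// Illustration only: two samplers with identical bounds and an identical seed
// produce the same sample over the same input, so results are reproducible.
def sampleOnce(): Seq[Int] = {
  val sampler = new BernoulliCellSampler[Int](0.0, 0.5)  // keep roughly half the rows
  sampler.setSeed(42L)
  sampler.sample(Iterator.range(0, 100)).toSeq
}
assert(sampleOnce() == sampleOnce())  // deterministic given a fixed seed
```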

iterator = sampler.sample(child.asIterator)
}

@@ -238,7 +238,7 @@ object SparkPlanTest {
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan.transformExpressions {
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

/**
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
*/
private[local] case class DummyNode(
output: Seq[Attribute],
relation: LocalRelation,
conf: SQLConf)
extends LocalNode(conf) {

import DummyNode._

private var index: Int = CLOSED
private val input: Seq[InternalRow] = relation.data

def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
this(output, LocalRelation.fromProduct(output, data), conf)
}

def isOpen: Boolean = index != CLOSED

override def children: Seq[LocalNode] = Seq.empty

override def open(): Unit = {
index = -1
}

override def next(): Boolean = {
index += 1
index < input.size
}

override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
input(index)
}

override def close(): Unit = {
index = CLOSED
}
}

private object DummyNode {
val CLOSED: Int = Int.MinValue
}
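A rough usage sketch of the new DummyNode (illustrative only; it assumes it runs somewhere DummyNode is visible, since the class is private[local], and the k/v attributes below stand in for the kvIntAttributes helper the suites get from LocalNodeTest):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Stand-in for LocalNodeTest.kvIntAttributes: an (Int, Int) key/value schema.
val kvIntAttributes = Seq(
  AttributeReference("k", IntegerType)(),
  AttributeReference("v", IntegerType)())

val data = (1 to 5).map { i => (i, i * 10) }
val node = new DummyNode(kvIntAttributes, data)  // uses the default SQLConf

// Same open/collect/close pattern as LocalNodeSuite below.
node.open()
assert(node.isOpen)
val rows = node.collect().map { row => (row.getInt(0), row.getInt(1)) }
assert(rows == data)
node.close()
assert(!node.isOpen)
```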

@@ -17,25 +17,29 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._

class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {

test("basic") {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
node => FilterNode(conf, condition.expr, node),
testData.filter(condition).collect()
)
class FilterNodeSuite extends LocalNodeTest {

private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val cond = 'k % 2 === 0
val inputNode = new DummyNode(kvIntAttributes, inputData)
val filterNode = new FilterNode(conf, cond, inputNode)
val resolvedNode = resolveExpressions(filterNode)
val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}

test("empty") {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
node => FilterNode(conf, condition.expr, node),
emptyTestData.filter(condition).collect()
)
testFilter()
}

test("basic") {
testFilter((1 to 100).map { i => (i, i) }.toArray)
}

}

@@ -17,19 +17,21 @@

package org.apache.spark.sql.execution.local

class IntersectNodeSuite extends LocalNodeTest {

import testImplicits._
class IntersectNodeSuite extends LocalNodeTest {

test("basic") {
val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")

checkAnswer2(
input1,
input2,
(node1, node2) => IntersectNode(conf, node1, node2),
input1.intersect(input2).collect()
)
val n = 100
val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray
val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray
val leftNode = new DummyNode(kvIntAttributes, leftData)
val rightNode = new DummyNode(kvIntAttributes, rightData)
val intersectNode = new IntersectNode(conf, leftNode, rightNode)
val expectedOutput = leftData.intersect(rightData)
val actualOutput = intersectNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}

}

@@ -17,23 +17,25 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext

class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
class LimitNodeSuite extends LocalNodeTest {

test("basic") {
checkAnswer(
testData,
node => LimitNode(conf, 10, node),
testData.limit(10).collect()
)
private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = {
val inputNode = new DummyNode(kvIntAttributes, inputData)
val limitNode = new LimitNode(conf, limit, inputNode)
val expectedOutput = inputData.take(limit)
val actualOutput = limitNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}

test("empty") {
checkAnswer(
emptyTestData,
node => LimitNode(conf, 10, node),
emptyTestData.limit(10).collect()
)
testLimit()
}

test("basic") {
testLimit((1 to 100).map { i => (i, i) }.toArray, 20)
}

}

@@ -17,45 +17,42 @@

package org.apache.spark.sql.execution.local

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

class LocalNodeSuite extends SparkFunSuite {
private val data = (1 to 100).toArray
class LocalNodeSuite extends LocalNodeTest {
private val data = (1 to 100).map { i => (i, i) }.toArray

test("basic open, next, fetch, close") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
assert(!node.isOpen)
node.open()
assert(node.isOpen)
data.foreach { i =>
data.foreach { case (k, v) =>
assert(node.next())
// fetch should be idempotent
val fetched = node.fetch()
assert(node.fetch() === fetched)
assert(node.fetch() === fetched)
assert(node.fetch().numFields === 1)
assert(node.fetch().getInt(0) === i)
assert(node.fetch().numFields === 2)
assert(node.fetch().getInt(0) === k)
assert(node.fetch().getInt(1) === v)
}
assert(!node.next())
node.close()
assert(!node.isOpen)
}

test("asIterator") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
val iter = node.asIterator
node.open()
data.foreach { i =>
data.foreach { case (k, v) =>
// hasNext should be idempotent
assert(iter.hasNext)
assert(iter.hasNext)
val item = iter.next()
assert(item.numFields === 1)
assert(item.getInt(0) === i)
assert(item.numFields === 2)
assert(item.getInt(0) === k)
assert(item.getInt(1) === v)
}
intercept[NoSuchElementException] {
iter.next()
@@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite {
}

test("collect") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
node.open()
val collected = node.collect()
assert(collected.size === data.size)
assert(collected.forall(_.size === 1))
assert(collected.map(_.getInt(0)) === data)
assert(collected.forall(_.size === 2))
assert(collected.map { case row => (row.getInt(0), row.getInt(1)) } === data)
node.close()
}

}

/**
* A dummy [[LocalNode]] that just returns one row per integer in the input.
*/
private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
private var index = Int.MinValue

def this(input: Array[Int]) {
this(new SQLConf, input)
}

def isOpen: Boolean = {
index != Int.MinValue
}

override def output: Seq[Attribute] = {
Seq(AttributeReference("something", IntegerType)())
}

override def children: Seq[LocalNode] = Seq.empty

override def open(): Unit = {
index = -1
}

override def next(): Boolean = {
index += 1
index < input.size
}

override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
val values = Array(input(index).asInstanceOf[Any])
new GenericInternalRow(values)
}

override def close(): Unit = {
index = Int.MinValue
}
}