Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object GroupedIterator {
keyExpressions: Seq[Expression],
inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
if (input.hasNext) {
new GroupedIterator(input, keyExpressions, inputSchema)
new GroupedIterator(input.buffered, keyExpressions, inputSchema)
} else {
Iterator.empty
}
Expand Down Expand Up @@ -64,7 +64,7 @@ object GroupedIterator {
* @param inputSchema The schema of the rows in the `input` iterator.
*/
class GroupedIterator private(
input: Iterator[InternalRow],
input: BufferedIterator[InternalRow],
groupingExpressions: Seq[Expression],
inputSchema: Seq[Attribute])
extends Iterator[(InternalRow, Iterator[InternalRow])] {
Expand All @@ -83,10 +83,17 @@ class GroupedIterator private(

/** Holds a copy of an input row that is in the current group. */
var currentGroup = currentRow.copy()
var currentIterator: Iterator[InternalRow] = null

assert(keyOrdering.compare(currentGroup, currentRow) == 0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we only use keyOrdering for an equality check, so why not just use ==? currentGroup and currentRow come from the same input, so they must both be unsafe rows or both be safe rows, and == on UnsafeRow is faster than keyOrdering.compare.

cc @marmbrus

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the whole row, not just the key. This allows us to do the equality check on the key columns only (which might short circuit) instead of doing a full projection on each row to extract the key columns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry I missed it

var currentIterator = createGroupValuesIterator()

// Return true if we already have the next iterator or fetching a new iterator is successful.
/**
 * Reports whether another group is available: either a group iterator is already pending in
 * `currentIterator`, or a new one can be fetched from the input.
 *
 * Note: a values iterator obtained from `next` should be consumed before `hasNext` is called
 * again. Fetching a new group iterator consumes input rows to skip ahead to the next group,
 * which leaves the previously returned values iterator empty.
 */
def hasNext: Boolean = {
  if (currentIterator != null) true else fetchNextGroupIterator()
}

def next(): (InternalRow, Iterator[InternalRow]) = {
Expand All @@ -96,46 +103,64 @@ class GroupedIterator private(
ret
}

def fetchNextGroupIterator(): Boolean = {
if (currentRow != null || input.hasNext) {
val inputIterator = new Iterator[InternalRow] {
// Return true if we have a row and it is in the current group, or if fetching a new row is
// successful.
def hasNext = {
(currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
fetchNextRowInGroup()
}
/**
 * Advances to the next group and, if one exists, installs a fresh values iterator for it in
 * `currentIterator`.
 *
 * Must only be called when `currentIterator` is null (asserted). Rows that still belong to the
 * current group are skipped, so any values iterator handed out earlier will see no more rows.
 *
 * @return true if a new group was found and `currentIterator` was set; false if the input has
 *         no rows beyond the current group.
 */
private def fetchNextGroupIterator(): Boolean = {
  assert(currentIterator == null)

  if (currentRow == null && input.hasNext) {
    currentRow = input.next()
  }

  if (currentRow == null) {
    // There is no data left, return false.
    false
  } else {
    // Skip to the next group.
    // The comparison must run BEFORE `input.hasNext`: calling `hasNext` may advance the
    // underlying iterator, and an iterator that reuses a mutable row buffer would then
    // overwrite `currentRow` before we had a chance to compare it against `currentGroup`.
    while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) {
      currentRow = input.next()
    }

    if (keyOrdering.compare(currentGroup, currentRow) == 0) {
      // We are still in the last group and there are no more groups, return false.
      false
    } else {
      // Now `currentRow` is the first row of the next group.
      currentGroup = currentRow.copy()
      currentIterator = createGroupValuesIterator()
      true
    }
  }
}

private def createGroupValuesIterator(): Iterator[InternalRow] = {
new Iterator[InternalRow] {
def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()

def next(): InternalRow = {
assert(hasNext)
val res = currentRow
currentRow = null
res
}

def fetchNextRowInGroup(): Boolean = {
if (currentRow != null || input.hasNext) {
private def fetchNextRowInGroup(): Boolean = {
assert(currentRow == null)

if (input.hasNext) {
// The inner iterator must NOT consume input that belongs to the next group, so we use `head`
// to peek at the next row and decide whether it still belongs to the current group.
if (keyOrdering.compare(currentGroup, input.head) == 0) {
// Next input is in the current group. Continue the inner iterator.
currentRow = input.next()
if (keyOrdering.compare(currentGroup, currentRow) == 0) {
// The row is in the current group. Continue the inner iterator.
true
} else {
// We got a row, but it's not in the right group. End this inner iterator and prepare
// for the next group.
currentIterator = null
currentGroup = currentRow.copy()
false
}
true
} else {
// There is no more input so we are done.
// Next input is not in the right group. End this inner iterator.
false
}
}

def next(): InternalRow = {
assert(hasNext) // Ensure we have fetched the next row.
val res = currentRow
currentRow = null
res
} else {
// There is no more data, return false.
false
}
}
currentIterator = inputIterator
true
} else {
false
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add the apache header


import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

/**
 * Unit tests for `GroupedIterator`: rows are encoded to internal rows, grouped by one or more
 * bound key expressions, and the resulting (key, values) pairs are checked against the input.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val rowSchema = new StructType().add("i", IntegerType).add("s", StringType)
    val rowEncoder = RowEncoder(rowSchema)
    val rows = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val groupingKeys = Seq('i.int.at(0))
    val groups = GroupedIterator(
      rows.iterator.map(rowEncoder.toRow), groupingKeys, rowSchema.toAttributes)

    // Keep the lazy `toSeq` conversions: each group's values are decoded as they are pulled.
    val actual = groups.map { case (groupKey, groupRows) =>
      assert(groupKey.numFields == 1)
      (groupKey.getInt(0), groupRows.map(rowEncoder.fromRow).toSeq)
    }.toSeq

    val expected = Seq(
      1 -> Seq(rows(0), rows(1)),
      2 -> Seq(rows(2)))
    assert(actual == expected)
  }

  test("group by 2 columns") {
    val rowSchema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val rowEncoder = RowEncoder(rowSchema)

    val rows = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val groupingKeys = Seq('i.int.at(0), 'l.long.at(1))
    val groups = GroupedIterator(
      rows.iterator.map(rowEncoder.toRow), groupingKeys, rowSchema.toAttributes)

    val actual = groups.map { case (groupKey, groupRows) =>
      assert(groupKey.numFields == 2)
      (groupKey.getInt(0), groupKey.getLong(1), groupRows.map(rowEncoder.fromRow).toSeq)
    }.toSeq

    val expected = Seq(
      (1, 2L, Seq(rows(0), rows(1))),
      (1, 3L, Seq(rows(2))),
      (2, 1L, Seq(rows(3))),
      (3, 2L, Seq(rows(4))))
    assert(actual == expected)
  }

  test("do nothing to the value iterator") {
    val rowSchema = new StructType().add("i", IntegerType).add("s", StringType)
    val rowEncoder = RowEncoder(rowSchema)
    val rows = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val groupingKeys = Seq('i.int.at(0))
    val groups = GroupedIterator(
      rows.iterator.map(rowEncoder.toRow), groupingKeys, rowSchema.toAttributes)

    // Count the groups without ever touching the per-group value iterators.
    val groupCount = groups.length
    assert(groupCount == 2)
  }
}