org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,6 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.execution.local.LocalNode
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -38,7 +39,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
 * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
 * object.
 */
-private[joins] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation {
  def get(key: InternalRow): Seq[InternalRow]

  // This is a helper method to implement Externalizable, and is used by
@@ -111,7 +112,7 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.


-private[joins] object HashedRelation {
+private[execution] object HashedRelation {

  def apply(
      input: Iterator[InternalRow],
@@ -163,6 +164,63 @@ private[joins] object HashedRelation {
      new GeneralHashedRelation(hashTable)
    }
  }

  /**
   * Consume a `LocalNode` to create a `HashedRelation`. Because `LocalNode` is not an `Iterator`,
Contributor:
nit: use [[LocalNode]] and [[HashedRelation]] instead

   * we cannot use `apply` directly. Moreover, to avoid creating another layer of `Iterator`, we
   * have to duplicate most of the code of `apply` here.
Contributor:
I actually think implementing the wrapper is better since it's not very complicated. Duplicate code in general is really bad and hard to maintain. We can have something like the following in LocalNode

def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)

then provide the dummy SQLMetrics.nullLongMetric
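A rough sketch of the call site this suggestion would enable (untested; `asIterator` and the dummy metric are the names proposed in this thread, and the iterator-based `apply` is assumed to take a row-count metric as its second argument; `buildNode` and `buildSideKeyGenerator` follow the naming used later in this diff):

buildNode.open()
val hashed: HashedRelation = HashedRelation(
  buildNode.asIterator,         // proposed wrapper around the LocalNode
  SQLMetrics.nullLongMetric,    // dummy metric, as suggested above
  buildSideKeyGenerator)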

Contributor:
I wrote some code for this. Feel free to steal or come up with something better. (not tested!)

/**
 * A thin wrapper around a [[LocalNode]] that provides an iterator interface.
 */
private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
  private var nextRow: InternalRow = _

  override def hasNext: Boolean = {
    if (nextRow == null) {
      val res = localNode.next()
      if (res) {
        nextRow = localNode.fetch()
      }
      res
    } else {
      true
    }
  }

  override def next(): InternalRow = {
    if (hasNext) {
      val res = nextRow
      nextRow = null
      res
    } else {
      throw new NoSuchElementException
    }
  }
}

Member Author:
Thanks. Added LocalNodeIterator to this PR :)

   *
   * Note: the default parameter conflicts with overloading, so this method is named
   * something other than `apply`.
Contributor:
Note: this method must be called something other than apply because of default parameter overloading conflicts.
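
For context, Scala forbids default arguments on more than one overloaded alternative, which is why a second `apply` overload would not compile here. A minimal illustration (hypothetical types, not code from this PR):

object Example {
  def apply(input: Iterator[Int], sizeEstimate: Int = 64): Int = sizeEstimate
  def apply(input: List[Int], sizeEstimate: Int = 64): Int = sizeEstimate
}
// error: in object Example, multiple overloaded alternatives of
// method apply define default arguments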

   */
  def createLocalHashedRelation(
      input: LocalNode,
      keyGenerator: Projection,
      sizeEstimate: Int = 64): HashedRelation = {

    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
      return UnsafeHashedRelation(
        input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
    }

    // TODO: Use Spark's HashMap implementation.
    val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
    var currentRow: InternalRow = null

    // Whether the join key is unique. If the key is unique, we can convert the underlying
    // hash map into one specialized for this.
    var keyIsUnique = true

    // Create a mapping of buildKeys -> rows
    while (input.next()) {
      currentRow = input.fetch()
      val rowKey = keyGenerator(currentRow)
      if (!rowKey.anyNull) {
        val existingMatchList = hashTable.get(rowKey)
        val matchList = if (existingMatchList == null) {
          val newMatchList = new CompactBuffer[InternalRow]()
          hashTable.put(rowKey.copy(), newMatchList)
          newMatchList
        } else {
          keyIsUnique = false
          existingMatchList
        }
        matchList += currentRow.copy()
      }
    }

    if (keyIsUnique) {
      val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
      val iter = hashTable.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        uniqHashTable.put(entry.getKey, entry.getValue()(0))
      }
      new UniqueKeyHashedRelation(uniqHashTable)
    } else {
      new GeneralHashedRelation(hashTable)
    }
  }
}
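
A usage sketch of this method, mirroring what HashJoinNode.open() does later in this diff:

// Consume the build side once; the relation is then probed per streamed row.
buildNode.open()
val hashed: HashedRelation =
  HashedRelation.createLocalHashedRelation(buildNode, buildSideKeyGenerator)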

/**
@@ -362,7 +420,7 @@ private[joins] final class UnsafeHashedRelation(
}
}

-private[joins] object UnsafeHashedRelation {
+private[execution] object UnsafeHashedRelation {

  def apply(
      input: Iterator[InternalRow],
@@ -393,4 +451,32 @@ private[joins] object UnsafeHashedRelation {

    new UnsafeHashedRelation(hashTable)
  }

  def apply(
      input: LocalNode,
      keyGenerator: UnsafeProjection,
      sizeEstimate: Int): HashedRelation = {

    // Use a Java hash table here because unsafe maps expect fixed size records
    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)

    // Create a mapping of buildKeys -> rows
    while (input.next()) {
      val unsafeRow = input.fetch().asInstanceOf[UnsafeRow]
      val rowKey = keyGenerator(unsafeRow)
      if (!rowKey.anyNull) {
        val existingMatchList = hashTable.get(rowKey)
        val matchList = if (existingMatchList == null) {
          val newMatchList = new CompactBuffer[UnsafeRow]()
          hashTable.put(rowKey.copy(), newMatchList)
          newMatchList
        } else {
          existingMatchList
        }
        matchList += unsafeRow.copy()
      }
    }

    new UnsafeHashedRelation(hashTable)
  }
}
org/apache/spark/sql/execution/local/ConvertToSafeNode.scala (new file)
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection}

case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {

  override def output: Seq[Attribute] = child.output

  private[this] var convertToSafe: Projection = _

  override def open(): Unit = {
    child.open()
    convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
Contributor:
minor: could be child.schema

  }

  override def next(): Boolean = child.next()

  override def fetch(): InternalRow = convertToSafe(child.fetch())

  override def close(): Unit = child.close()
}
org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala (new file)
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection}

case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {

  override def output: Seq[Attribute] = child.output

  private[this] var convertToUnsafe: Projection = _

  override def open(): Unit = {
    child.open()
    convertToUnsafe = UnsafeProjection.create(child.schema)
  }

  override def next(): Boolean = child.next()

  override def fetch(): InternalRow = convertToUnsafe(child.fetch())

  override def close(): Unit = child.close()
}
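
A hedged composition sketch for these converter nodes (the `childIsUnsafe` flag is hypothetical, not part of this PR):

// Insert a conversion only when the child does not already produce UnsafeRows.
def ensureUnsafe(conf: SQLConf, child: LocalNode, childIsUnsafe: Boolean): LocalNode = {
  if (childIsUnsafe) child else ConvertToUnsafeNode(conf, child)
}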
org/apache/spark/sql/execution/local/FilterNode.scala
@@ -17,12 +17,14 @@

package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate


-case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode {
+case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
+  extends UnaryLocalNode(conf) {

  private[this] var predicate: (InternalRow) => Boolean = _

org/apache/spark/sql/execution/local/HashJoinNode.scala (new file)
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins._

case class HashJoinNode (
Contributor:
no space

Contributor:
also can you add a comment here that says much of this code is similar to HashJoin#hashJoin?

    conf: SQLConf,
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    buildSide: BuildSide,
    left: LocalNode,
    right: LocalNode) extends BinaryLocalNode(conf) {

  private[this] lazy val (buildNode, streamedNode) = buildSide match {
    case BuildLeft => (left, right)
    case BuildRight => (right, left)
  }

  private[this] lazy val (buildKeys, streamedKeys) = buildSide match {
    case BuildLeft => (leftKeys, rightKeys)
    case BuildRight => (rightKeys, leftKeys)
  }
Contributor:
this could be

private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = {
  buildSide match {
    case BuildLeft => (left, leftKeys, right, rightKeys)
    case BuildRight => (right, rightKeys, left, leftKeys)
  }
}

:)


  override def output: Seq[Attribute] = left.output ++ right.output

  private[this] def isUnsafeMode: Boolean = {
    (codegenEnabled && unsafeEnabled
      && UnsafeProjection.canSupport(buildKeys)
      && UnsafeProjection.canSupport(schema))
  }

  private[this] def buildSideKeyGenerator: Projection =
Contributor:
can you add { } around this method and streamSideKeyGenerator?

    if (isUnsafeMode) {
      UnsafeProjection.create(buildKeys, buildNode.output)
    } else {
      newMutableProjection(buildKeys, buildNode.output)()
    }

  private[this] def streamSideKeyGenerator: Projection =
    if (isUnsafeMode) {
      UnsafeProjection.create(streamedKeys, streamedNode.output)
    } else {
      newMutableProjection(streamedKeys, streamedNode.output)()
    }

  private[this] var currentStreamedRow: InternalRow = _
  private[this] var currentHashMatches: Seq[InternalRow] = _
  private[this] var currentMatchPosition: Int = -1

  // Mutable per row objects.
Contributor:
what does this mean?

  private[this] var joinRow: JoinedRow = _
  private[this] var resultProjection: (InternalRow) => InternalRow = _

  private[this] var hashed: HashedRelation = _
  private[this] var joinKeys: Projection = _
Contributor:
nit: can you put all the vars before defs so it's easy to find them?

Member Author:
Sure. Moved them.


  override def open(): Unit = {
    buildNode.open()
    hashed = HashedRelation.createLocalHashedRelation(buildNode, buildSideKeyGenerator)
    streamedNode.open()
    joinRow = new JoinedRow
    resultProjection = {
      if (isUnsafeMode) {
        UnsafeProjection.create(schema)
      } else {
        identity[InternalRow]
      }
    }
    joinKeys = streamSideKeyGenerator
  }

  override def next(): Boolean = {
    if (currentMatchPosition != -1) {
      currentMatchPosition += 1
      if (currentMatchPosition < currentHashMatches.size) {
        true
      } else {
        fetchNextMatch()
      }
    } else {
      fetchNextMatch()
    }
Contributor:
I think the following is functionally equivalent and easier to read:

currentMatchPosition += 1
if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) {
  fetchNextMatch()
} else {
  true
}

which says if we don't currently have matches, or we've already joined all of our existing matches, then fetch more matches.

  }

  private def fetchNextMatch(): Boolean = {
Contributor:
can you add some java docs here:

/**
 * Populate `currentHashMatches` with build-side rows matching the next streamed row.
 * @return whether matches are found such that subsequent calls to `fetch` are valid.
 */

    currentHashMatches = null
    currentMatchPosition = -1

    while (currentHashMatches == null && streamedNode.next()) {
      currentStreamedRow = streamedNode.fetch()
      val key = joinKeys(currentStreamedRow)
      if (!key.anyNull) {
        currentHashMatches = hashed.get(key)
      }
    }

    if (currentHashMatches == null) {
      false
    } else {
      currentMatchPosition = 0
      true
    }
  }

  override def fetch(): InternalRow = {
    val ret = buildSide match {
      case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
      case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
    }
    resultProjection(ret)
  }

  override def close(): Unit = {
    left.close()
    right.close()
  }
}
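
For reference, a minimal driver sketch of the open/next/fetch/close protocol these LocalNodes implement (a hypothetical consumer, not code from this PR):

def collect(node: LocalNode): Seq[InternalRow] = {
  val rows = new scala.collection.mutable.ArrayBuffer[InternalRow]
  node.open()
  try {
    while (node.next()) {
      rows += node.fetch().copy() // copy: fetch() may return a reused mutable row
    }
  } finally {
    node.close()
  }
  rows
}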
org/apache/spark/sql/execution/local/LimitNode.scala
@@ -17,11 +17,12 @@

package org.apache.spark.sql.execution.local

+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute


-case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) {

  private[this] var count = 0
