@@ -490,7 +490,7 @@ private static StructType structFromTypes(DataType[] format) {
return new StructType(fields);
}

-  private static StructType structFromAttributes(List<Attribute> format) {
+  public static StructType structFromAttributes(List<Attribute> format) {
StructField[] fields = new StructField[format.size()];
int i = 0;
for (Attribute attribute: format) {
@@ -18,11 +18,15 @@ package com.nvidia.spark.rapids

import scala.collection.mutable.ArrayBuffer

+import ai.rapids.cudf.Table
+import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq

import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}

/**
 * Utility class with methods for calculating various metrics about GPU memory usage
- * prior to allocation.
+ * prior to allocation, along with some common operations on batches.
 */
object GpuBatchUtils {

@@ -175,4 +179,37 @@ object GpuBatchUtils {
bytes
}
}

  /**
   * Concatenate the input batches into a single one.
   * The caller is responsible for closing the returned batch.
   *
   * @param spillBatches the batches to be concatenated; they will be closed when this
   *                     call returns.
   * @return the concatenated SpillableColumnarBatch, or None if the input is empty.
   */
  def concatSpillBatchesAndClose(
      spillBatches: Seq[SpillableColumnarBatch]): Option[SpillableColumnarBatch] = {
    val retBatch = if (spillBatches.length >= 2) {
      // two or more batches, concatenate them
      val (concatTable, types) = RmmRapidsRetryIterator.withRetryNoSplit(spillBatches) { _ =>
        withResource(spillBatches.safeMap(_.getColumnarBatch())) { batches =>
          val batchTypes = GpuColumnVector.extractTypes(batches.head)
          withResource(batches.safeMap(GpuColumnVector.from)) { tables =>
            (Table.concatenate(tables: _*), batchTypes)
          }
        }
      }
      // Make the concatenated table spillable.
      withResource(concatTable) { _ =>
        SpillableColumnarBatch(GpuColumnVector.from(concatTable, types),
          SpillPriorities.ACTIVE_BATCHING_PRIORITY)
      }
    } else if (spillBatches.length == 1) {
      // only one batch
      spillBatches.head
    } else null

    Option(retBatch)
  }
}
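For reference, a minimal usage sketch of the new helper (the input collection "inputBatches" and the processing step are hypothetical, not part of this change): a caller wraps its batches as spillable, hands ownership of them to the helper, and closes whatever it returns.

// Sketch only: assumes "inputBatches: Seq[ColumnarBatch]" plus the names already in
// scope in this file (SpillableColumnarBatch, SpillPriorities, withResource).
val spillable = inputBatches.map { cb =>
  SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_BATCHING_PRIORITY)
}
// Ownership of "spillable" moves to the helper; it yields None for an empty input.
GpuBatchUtils.concatSpillBatchesAndClose(spillable) match {
  case Some(scb) => withResource(scb) { _ =>
    withResource(scb.getColumnarBatch()) { concatenated =>
      // process the single concatenated ColumnarBatch here
    }
  }
  case None => // no input batches, nothing to do
}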
@@ -18,8 +18,7 @@ package org.apache.spark.sql.rapids.execution
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

-import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, RmmRapidsRetryIterator, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource}
+import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._

@@ -41,27 +40,7 @@ object GpuSubPartitionHashJoin {
*/
def concatSpillBatchesAndClose(
spillBatches: Seq[SpillableColumnarBatch]): Option[SpillableColumnarBatch] = {
-    val retBatch = if (spillBatches.length >= 2) {
-      // two or more batches, concatenate them
-      val (concatTable, types) = RmmRapidsRetryIterator.withRetryNoSplit(spillBatches) { _ =>
-        withResource(spillBatches.safeMap(_.getColumnarBatch())) { batches =>
-          val batchTypes = GpuColumnVector.extractTypes(batches.head)
-          withResource(batches.safeMap(GpuColumnVector.from)) { tables =>
-            (Table.concatenate(tables: _*), batchTypes)
-          }
-        }
-      }
-      // Make the concatenated table spillable.
-      withResource(concatTable) { _ =>
-        SpillableColumnarBatch(GpuColumnVector.from(concatTable, types),
-          SpillPriorities.ACTIVE_BATCHING_PRIORITY)
-      }
-    } else if (spillBatches.length == 1) {
-      // only one batch
-      spillBatches.head
-    } else null
-
-    Option(retBatch)
+    GpuBatchUtils.concatSpillBatchesAndClose(spillBatches)
}

/**
@@ -29,8 +29,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.rapids.execution.GpuSubPartitionHashJoin
-import org.apache.spark.sql.rapids.execution.python.shims.GpuPythonArrowOutput
+import org.apache.spark.sql.rapids.execution.python.shims.GpuBasePythonRunner
import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -197,7 +196,7 @@ private[python] object BatchGroupUtils {
def executePython[IN](
pyInputIterator: Iterator[IN],
output: Seq[Attribute],
-      pyRunner: GpuPythonRunnerBase[IN],
+      pyRunner: GpuBasePythonRunner[IN],
outputRows: GpuMetric,
outputBatches: GpuMetric): Iterator[ColumnarBatch] = {
val context = TaskContext.get()
@@ -396,7 +395,7 @@ private[python] object BatchGroupedIterator {
class CombiningIterator(
inputBatchQueue: BatchQueue,
pythonOutputIter: Iterator[ColumnarBatch],
-    pythonArrowReader: GpuPythonArrowOutput,
+    pythonArrowReader: GpuArrowOutput,
numOutputRows: GpuMetric,
numOutputBatches: GpuMetric) extends Iterator[ColumnarBatch] {

@@ -456,7 +455,7 @@ class CombiningIterator(
pendingInput = Some(second)
}

-    val ret = GpuSubPartitionHashJoin.concatSpillBatchesAndClose(buf.toSeq)
+    val ret = GpuBatchUtils.concatSpillBatchesAndClose(buf.toSeq)
// "ret" should be non empty because we checked the buf is not empty ahead.
withResource(ret.get) { concatedScb =>
concatedScb.getColumnarBatch()
@@ -596,3 +595,4 @@ class CoGroupedIterator(
keyOrdering.compare(leftKeyRow, rightKeyRow)
}
}
