Skip to content

Conversation

@BryanCutler
Copy link
Owner

@BryanCutler BryanCutler commented Feb 8, 2017

What changes were proposed in this pull request?

Changed conversion to be iterator based, instead of Array[InternalRows]

Moved Dataset conversion to partitions, then collect as byte array

Added some pydocs

How was this patch tested?

Passing tests now, need to test conversion with 1 partition to match json file

@BryanCutler
Copy link
Owner Author

Benchmarks seem promising so far (unless something is terribly wrong)

collect as internal rows                                                        
count    50.000000
mean      0.213734
std       0.038923
min       0.159922
25%       0.188007
50%       0.205784
75%       0.224724
max       0.392772
dtype: float64
toPandas with arrow
count    50.000000
mean      0.064689
std       0.035405
min       0.050057
25%       0.055823
50%       0.059600
75%       0.062314
max       0.305428
dtype: float64
toPandas without arrow
count    50.000000                                                              
mean      2.174661
std       0.116211
min       2.011181
25%       2.098154
50%       2.138159
75%       2.226889
max       2.515083
dtype: float64



//////////////////////
private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) extends SeekableByteChannel {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from somewhere in arrow test cases, is there an existing class that might work or maybe add to a public util package?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrowReader?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I use this class in ArrowReader because it requires a SeekableByteChannel, but the ones available in java.nio aren't seekable

val buffers = fieldAndBuf._2.flatten

val recordBatch = new ArrowRecordBatch(rows.length,
val rowLength = if(fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this seem acceptable to get the row length for creating an ArrowRecordBatch or better to keep a counter myself?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need another counter here.

When is fieldNodes empty?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will probably never be empty, but might be good to keep the check just in case

override def init(initialSize: Int): Unit = {
override def init(initialSize: Option[Int]): this.type = {
initialSize.foreach(valueVector.setInitialCapacity)
valueVector.allocateNew()
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed this was never used before, would it increase performance much to set the exact capacity? Although I'm not sure its possible when using Iterator[InternalRow] without making a first pass

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For iterator, we cannot know the size. I would just set a constant initial capacity.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll remove the option to set an initial capacity now. The default capacity seems to be fine.

assert(arrowPayload.nonEmpty)
arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0))
assert(arrowPayload.isEmpty)
// TODO: test empty partitions
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since ArrowPayload is an Iterator, I guess it makes sense that an empty DataFrame will convert to an empty Iterator. However, if a partition is zero length, then it will convert that to a zero length ArrowRecordBatch. Sound right?


val arrowPayload = df.collectAsArrow(Some(converter))
// TODO: repartition because can only validate one batch
val arrowPayload = df.repartition(1).collectAsArrow(Some(converter))
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests fail if using more than 1 partition because it will produce more than 1 ArrowRecordBatches and I think VectorLoader will just load 1 batch. Is there a way to combine batches or load more than 1 in Java?

@BryanCutler
Copy link
Owner Author

@wesm @icexelloss this is a first stab at running the Dataset conversion at partition level so it can be done by executors. It seems promising so far, but I ran into some issues with the Java tests (see above) and Python tests, so I was hoping to get your input.

Python tests fail when using a String column and there is an empty partition. In arrow-cpp adapter.cc calls LoadBinary here and tries GetBuffer even though the length is 0. Is is possible to put the same checks for this as in LoadPrimitive?

@wesm
Copy link

wesm commented Feb 9, 2017

hi @BryanCutler I created https://issues.apache.org/jira/browse/ARROW-544, I will try to reproduce and fix

val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) =>
internalRowToArrowBuf(rows, ordinal, field, allocator)
allocator: RootAllocator,
initialSize: Option[Int]): ArrowRecordBatch = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can get rid of initialSize for an iterator orientated implementation

try {
val collectedRows = queryExecution.executedPlan.executeCollect()
cnvtr.internalRowsToPayload(collectedRows, this.schema)
val rowIter = queryExecution.executedPlan.executeToIterator()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops, was looking at an old commit. Please ignore this.

.init(initialSize)
}

rowIter.foreach { row =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we even discussed this before I think.. Even here it might add a little overhead to wrap in a function object, so I'll change it to a while loop just to be sure.


def toArrowBatchBytes(): RDD[Array[Byte]] = {
val schema_captured = this.schema
queryExecution.toRdd.mapPartitions { iter =>
Copy link
Collaborator

@icexelloss icexelloss Feb 9, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should try to follow:

https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala#L224

Rename this to "getArrowByteArrayRdd" and move to SparkPlan?

use mapPartitionsInternal might be better

We can also try compression here

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I'll look into that! We can probably use mapPartitionsInternal because I don't think the closure needs cleaning..

Not sure if it should be moved to SparkPlan, I'll let others weigh in on that.. but it probably shouldn't be public.



//////////////////////
private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) extends SeekableByteChannel {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrowReader?

@icexelloss
Copy link
Collaborator

Nice work!

@icexelloss
Copy link
Collaborator

collecting as arrow batch is 3x faster than collecting as internal rows!

Honestly I don't understand why it is the case, @wesm

@wesm
Copy link

wesm commented Feb 9, 2017

@icexelloss it's because the conversion is being parallelized on task executors -- serializing Array[InternalRow] is not free. Presumably moving the serialized byte arrays adds comparatively overhead

@icexelloss
Copy link
Collaborator

I might be missing sth here.

Collect as internal rows (SparkPlan.executeCollect) doesn't do any arrow conversion, it does the normal unsafe row serialization which is just pretty much just sending the underlying byte array of an unsafe row. I wonder why it's 3x slower than doing the arrow conversion and then sending the arrow byte array.

Overhead of compression maybe?

@wesm
Copy link

wesm commented Feb 9, 2017

@icexelloss the conversion from Array[InternalRow] to Arrow record batches is being parallelized. Before, it was serial and single-threaded

@wesm
Copy link

wesm commented Feb 9, 2017

I'm getting a 38% performance improvement on the test case I have in my talk in about an hour -- I'll include the updated benchmarks (I already had slides about executor-local serialization as a "TODO"), so thank you

@BryanCutler
Copy link
Owner Author

Thanks for the review @icexelloss and @wesm! I was a little surprised too that the benchmarks are showing a speedup over collecting internal rows, but I think it's mostly due to the data being transferred as large batches instead of individual rows.

…al, updated Arrow and using microsecond timestamps
@icexelloss
Copy link
Collaborator

@BryanCutler , I have issues running code from your branch stc_toPandas_with_arrow.

Do you need to build pyarrow locally to make it work?

@BryanCutler
Copy link
Owner Author

BryanCutler commented Feb 10, 2017 via email

@icexelloss
Copy link
Collaborator

icexelloss commented Feb 10, 2017

No, I think I don't have the updated pyarrow ( I was using pyarrow-0.1 from condaforge before) so it's complaining about "from pyarrow.table import concat_tables"

Then I tried to install pyarrow locally, but hit:
(python setup.py build_ext --inplace)
[ 2%] Building CXX object CMakeFiles/pyarrow.dir/src/pyarrow/adapters/pandas.cc.o
/home/icexelloss/workspace/arrow/python/src/pyarrow/adapters/pandas.cc:52:14: error: ‘arrow::DictionaryType’ has not been declared
using arrow::DictionaryType;

I have arrow cpp installed in /usr/local:
(spark-dev) icexelloss@icexelloss-VirtualBox:/workspace/arrow/python$ echo $LD_LIBRARY_PATH
/usr/local/lib
(spark-dev) icexelloss@icexelloss-VirtualBox:
/workspace/arrow/python$ ls $LD_LIBRARY_PATH
libarrow.a libarrow_io.a libarrow_io.so libarrow_ipc.a libarrow_ipc.so libarrow_jemalloc.a libarrow_jemalloc.so libarrow.so pkgconfig python2.7 python3.5

I wonder if you have seen this when building pyarrow?

@BryanCutler
Copy link
Owner Author

BryanCutler commented Feb 10, 2017 via email

@wesm
Copy link

wesm commented Feb 10, 2017

@icexelloss the issue is that the libarrow you have installed is outdated. I have bash functions like these to help automate things

function parquet_cpp_update {
    mkdir -p ~/code/parquet-cpp/library-build
    pushd ~/code/parquet-cpp/library-build
    rm -rf *
    cmake -DCMAKE_INSTALL_PREFIX=$TP_DIR \
          -DCMAKE_CLANG_OPTIONS=$CMAKE_CLANG_OPTIONS \
          -DCMAKE_BUILD_TYPE=$TOOLCHAIN_BUILD_TYPE \
          -DPARQUET_ARROW=on \
          -DPARQUET_ZLIB_VENDORED=on \
          -DPARQUET_BUILD_TESTS=OFF \
          -DPARQUET_BUILD_EXECUTABLES=OFF \
          ..
    make -j8
    make install
    popd
}

function arrow_cpp_update {
    mkdir -p ~/code/arrow/cpp/library-build
    pushd ~/code/arrow/cpp/library-build
    rm -rf *
    cmake -DCMAKE_INSTALL_PREFIX=$TP_DIR \
          -DCMAKE_BUILD_TYPE=$TOOLCHAIN_BUILD_TYPE \
          -DARROW_BOOST_USE_SHARED=off \
          -DARROW_BUILD_BENCHMARKS=off \
          -DARROW_JEMALLOC=off \
          -DARROW_BUILD_TESTS=off \
          -DCMAKE_BUILD_TYPE=$TOOLCHAIN_BUILD_TYPE ..
    make -j8 VERBOSE=1
    make install
    popd
}

function update_tp_toolchain {
    arrow_cpp_update
    parquet_cpp_update
}

you may wish to remove the rm -rf * if you don't have any thirdparty $*_HOME variables set

@wesm
Copy link

wesm commented Feb 10, 2017

If you are building from source, you must be careful about release vs. debug builds (because the debug builds are much slower). I maintain both debug and release toolchains and have helper functions to switch between them:

function debug() {
    export TP_DIR=$DEBUG_TP_DIR
    export TOOLCHAIN_BUILD_TYPE=debug
    set_build_env
}

function release() {
    export TP_DIR=$RELEASE_TP_DIR
    export TOOLCHAIN_BUILD_TYPE=release
    set_build_env
}

@icexelloss
Copy link
Collaborator

Thanks @wesm. I will give it a try.

@BryanCutler
Copy link
Owner Author

Hi @icexelloss , just wondering if you were able to run this branch yet and confirm it works for you too?

BryanCutler added a commit that referenced this pull request Feb 23, 2017
arrow conversion done at partition by executors

some cleanup of APIs, made tests complete for non-complex data types

closes #23
BryanCutler pushed a commit that referenced this pull request Nov 10, 2020
### What changes were proposed in this pull request?
Push down filter through expand.  For case below:
```
create table t1(pid int, uid int, sid int, dt date, suid int) using parquet;
create table t2(pid int, vs int, uid int, csid int) using parquet;

SELECT
       years,
       appversion,
       SUM(uusers) AS users
FROM   (SELECT
               Date_trunc('year', dt)          AS years,
               CASE
                 WHEN h.pid = 3 THEN 'iOS'
                 WHEN h.pid = 4 THEN 'Android'
                 ELSE 'Other'
               END                             AS viewport,
               h.vs                            AS appversion,
               Count(DISTINCT u.uid)           AS uusers
               ,Count(DISTINCT u.suid)         AS srcusers
        FROM   t1 u
               join t2 h
                 ON h.uid = u.uid
        GROUP  BY 1,
                  2,
                  3) AS a
WHERE  viewport = 'iOS'
GROUP  BY 1,
          2
```

Plan. before this pr:
```
== Physical Plan ==
*(5) HashAggregate(keys=[years#30, appversion#32], functions=[sum(uusers#33L)])
+- Exchange hashpartitioning(years#30, appversion#32, 200), true, [id=apache#251]
   +- *(4) HashAggregate(keys=[years#30, appversion#32], functions=[partial_sum(uusers#33L)])
      +- *(4) HashAggregate(keys=[date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12], functions=[count(if ((gid#44 = 1)) u.`uid`apache#47 else null)])
         +- Exchange hashpartitioning(date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12, 200), true, [id=apache#246]
            +- *(3) HashAggregate(keys=[date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12], functions=[partial_count(if ((gid#44 = 1)) u.`uid`apache#47 else null)])
               +- *(3) HashAggregate(keys=[date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12, u.`uid`apache#47, u.`suid`apache#48, gid#44], functions=[])
                  +- Exchange hashpartitioning(date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12, u.`uid`apache#47, u.`suid`apache#48, gid#44, 200), true, [id=apache#241]
                     +- *(2) HashAggregate(keys=[date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12, u.`uid`apache#47, u.`suid`apache#48, gid#44], functions=[])
                        +- *(2) Filter (CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46 = iOS)
                           +- *(2) Expand [ArrayBuffer(date_trunc(year, cast(dt#9 as timestamp), Some(Etc/GMT+7)), CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END, vs#12, uid#7, null, 1), ArrayBuffer(date_trunc(year, cast(dt#9 as timestamp), Some(Etc/GMT+7)), CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END, vs#12, null, suid#10, 2)], [date_trunc('year', CAST(u.`dt` AS TIMESTAMP))apache#45, CASE WHEN (h.`pid` = 3) THEN 'iOS' WHEN (h.`pid` = 4) THEN 'Android' ELSE 'Other' END#46, vs#12, u.`uid`apache#47, u.`suid`apache#48, gid#44]
                              +- *(2) Project [uid#7, dt#9, suid#10, pid#11, vs#12]
                                 +- *(2) BroadcastHashJoin [uid#7], [uid#13], Inner, BuildRight
                                    :- *(2) Project [uid#7, dt#9, suid#10]
                                    :  +- *(2) Filter isnotnull(uid#7)
                                    :     +- *(2) ColumnarToRow
                                    :        +- FileScan parquet default.t1[uid#7,dt#9,suid#10] Batched: true, DataFilters: [isnotnull(uid#7)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-bin-hadoop3.2/spark-warehouse/t1], PartitionFilters: [], PushedFilters: [IsNotNull(uid)], ReadSchema: struct<uid:int,dt:date,suid:int>
                                    +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[2, int, true] as bigint))), [id=apache#233]
                                       +- *(1) Project [pid#11, vs#12, uid#13]
                                          +- *(1) Filter isnotnull(uid#13)
                                             +- *(1) ColumnarToRow
                                                +- FileScan parquet default.t2[pid#11,vs#12,uid#13] Batched: true, DataFilters: [isnotnull(uid#13)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-bin-hadoop3.2/spark-warehouse/t2], PartitionFilters: [], PushedFilters: [IsNotNull(uid)], ReadSchema: struct<pid:int,vs:int,uid:int>
```

Plan. after. this pr. :
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[years#0, appversion#2], functions=[sum(uusers#3L)], output=[years#0, appversion#2, users#5L])
   +- Exchange hashpartitioning(years#0, appversion#2, 5), true, [id=apache#71]
      +- HashAggregate(keys=[years#0, appversion#2], functions=[partial_sum(uusers#3L)], output=[years#0, appversion#2, sum#22L])
         +- HashAggregate(keys=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12], functions=[count(distinct uid#7)], output=[years#0, appversion#2, uusers#3L])
            +- Exchange hashpartitioning(date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, 5), true, [id=apache#67]
               +- HashAggregate(keys=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12], functions=[partial_count(distinct uid#7)], output=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, count#27L])
                  +- HashAggregate(keys=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, uid#7], functions=[], output=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, uid#7])
                     +- Exchange hashpartitioning(date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, uid#7, 5), true, [id=apache#63]
                        +- HashAggregate(keys=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles)) AS date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END AS CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, uid#7], functions=[], output=[date_trunc(year, cast(dt#9 as timestamp), Some(America/Los_Angeles))#23, CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END#24, vs#12, uid#7])
                           +- Project [uid#7, dt#9, pid#11, vs#12]
                              +- BroadcastHashJoin [uid#7], [uid#13], Inner, BuildRight, false
                                 :- Filter isnotnull(uid#7)
                                 :  +- FileScan parquet default.t1[uid#7,dt#9] Batched: true, DataFilters: [isnotnull(uid#7)], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/4l/7_c5c97s1_gb0d9_d6shygx00000gn/T/warehouse-c069d87..., PartitionFilters: [], PushedFilters: [IsNotNull(uid)], ReadSchema: struct<uid:int,dt:date>
                                 +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[2, int, false] as bigint)),false), [id=apache#58]
                                    +- Filter ((CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END = iOS) AND isnotnull(uid#13))
                                       +- FileScan parquet default.t2[pid#11,vs#12,uid#13] Batched: true, DataFilters: [(CASE WHEN (pid#11 = 3) THEN iOS WHEN (pid#11 = 4) THEN Android ELSE Other END = iOS), isnotnull..., Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/4l/7_c5c97s1_gb0d9_d6shygx00000gn/T/warehouse-c069d87..., PartitionFilters: [], PushedFilters: [IsNotNull(uid)], ReadSchema: struct<pid:int,vs:int,uid:int>

```

### Why are the changes needed?
Improve  performance, filter more data.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added UT

Closes apache#30278 from AngersZhuuuu/SPARK-33302.

Authored-by: angerszhu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
BryanCutler pushed a commit that referenced this pull request Apr 21, 2021
…join (build right side)

### What changes were proposed in this pull request?

This PR is to add code-gen support for left semi / left anti BroadcastNestedLoopJoin (build side is right side). The execution code path for build left side cannot fit into whole stage code-gen framework, so only add the code-gen for build right side here.

Reference: the iterator (non-code-gen) code path is `BroadcastNestedLoopJoinExec.leftExistenceJoin()` with `BuildRight`.

### Why are the changes needed?

Improve query CPU performance.
Tested with a simple query:

```
val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left semi broadcast nested loop join", N) {
  park.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_semi")
}
```

Seeing 5x run time improvement:

```
Running benchmark: left semi broadcast nested loop join
  Running case: left semi broadcast nested loop join codegen off
  Stopped after 2 iterations, 6958 ms
  Running case: left semi broadcast nested loop join codegen on
  Stopped after 5 iterations, 3383 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
left semi broadcast nested loop join:             Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------------------------
left semi broadcast nested loop join codegen off           3434           3479          65          6.1         163.7       1.0X
left semi broadcast nested loop join codegen on             672            677           5         31.2          32.1       5.1X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Changed existing unit test in `ExistenceJoinSuite.scala` to cover all code paths:
* left semi/anti + empty right side + empty condition
* left semi/anti + non-empty right side + empty condition
* left semi/anti + right side + non-empty condition

Added unit test in `WholeStageCodegenSuite.scala` to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.

Example query:

```
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_semi").explain("codegen")
```

Example generated code (`bnlj_doConsume_0` method):
This is for left semi join. The generated code for left anti join is mostly to be same as here, except L55 to be `if (bnlj_findMatchedRow_0 == false) {`.
```
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) Project [id#0L AS k1#2L]
+- *(2) BroadcastNestedLoopJoin BuildRight, LeftSemi, ((id#0L + 1) <= k2#6L)
   :- *(2) Range (0, 4, step=1, splits=2)
   +- BroadcastExchange IdentityBroadcastMode, [id=#23]
      +- *(1) Project [id#4L AS k2#6L]
         +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 031 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_findMatchedRow_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */
/* 041 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 042 */
/* 043 */       long bnlj_value_3 = -1L;
/* 044 */
/* 045 */       bnlj_value_3 = bnlj_expr_0_0 + 1L;
/* 046 */
/* 047 */       boolean bnlj_value_2 = false;
/* 048 */       bnlj_value_2 = bnlj_value_3 <= bnlj_value_1;
/* 049 */       if (!(false || !bnlj_value_2))
/* 050 */       {
/* 051 */         bnlj_findMatchedRow_0 = true;
/* 052 */         break;
/* 053 */       }
/* 054 */     }
/* 055 */     if (bnlj_findMatchedRow_0 == true) {
/* 056 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 057 */
/* 058 */       // common sub-expressions
/* 059 */
/* 060 */       range_mutableStateArray_0[3].reset();
/* 061 */
/* 062 */       range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 063 */       append((range_mutableStateArray_0[3].getRow()).copy());
/* 064 */
/* 065 */     }
/* 066 */
/* 067 */   }
/* 068 */
/* 069 */   private void initRange(int idx) {
/* 070 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 071 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 072 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 073 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 074 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 075 */     long partitionEnd;
/* 076 */
/* 077 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 078 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 079 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 080 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 081 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 082 */     } else {
/* 083 */       range_nextIndex_0 = st.longValue();
/* 084 */     }
/* 085 */     range_batchEnd_0 = range_nextIndex_0;
/* 086 */
/* 087 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 088 */     .multiply(step).add(start);
/* 089 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 090 */       partitionEnd = Long.MAX_VALUE;
/* 091 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 092 */       partitionEnd = Long.MIN_VALUE;
/* 093 */     } else {
/* 094 */       partitionEnd = end.longValue();
/* 095 */     }
/* 096 */
/* 097 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 098 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 099 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 100 */     if (range_numElementsTodo_0 < 0) {
/* 101 */       range_numElementsTodo_0 = 0;
/* 102 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 103 */       range_numElementsTodo_0++;
/* 104 */     }
/* 105 */   }
/* 106 */
/* 107 */   protected void processNext() throws java.io.IOException {
/* 108 */     // initialize Range
/* 109 */     if (!range_initRange_0) {
/* 110 */       range_initRange_0 = true;
/* 111 */       initRange(partitionIndex);
/* 112 */     }
/* 113 */
/* 114 */     while (true) {
/* 115 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 116 */         long range_nextBatchTodo_0;
/* 117 */         if (range_numElementsTodo_0 > 1000L) {
/* 118 */           range_nextBatchTodo_0 = 1000L;
/* 119 */           range_numElementsTodo_0 -= 1000L;
/* 120 */         } else {
/* 121 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 122 */           range_numElementsTodo_0 = 0;
/* 123 */           if (range_nextBatchTodo_0 == 0) break;
/* 124 */         }
/* 125 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 126 */       }
/* 127 */
/* 128 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 129 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 130 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 131 */
/* 132 */         bnlj_doConsume_0(range_value_0);
/* 133 */
/* 134 */         if (shouldStop()) {
/* 135 */           range_nextIndex_0 = range_value_0 + 1L;
/* 136 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 137 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 138 */           return;
/* 139 */         }
/* 140 */
/* 141 */       }
/* 142 */       range_nextIndex_0 = range_batchEnd_0;
/* 143 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 144 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 145 */       range_taskContext_0.killTaskIfInterrupted();
/* 146 */     }
/* 147 */   }
/* 148 */
/* 149 */ }
```

Closes apache#31874 from c21/code-semi-anti.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants