diff --git a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
index 97972753a78fa..16d93763ff038 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
@@ -312,4 +312,22 @@ test_that("Arrow optimization - unsupported types", {
   })
 })
 
+test_that("SPARK-32478: gapply() Arrow optimization - error message for schema mismatch", {
+  skip_if_not_installed("arrow")
+  df <- createDataFrame(list(list(a = 1L, b = "a")))
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", "true")
+  tryCatch({
+    expect_error(
+      count(gapply(df, "a", function(key, group) { group }, structType("a int, b int"))),
+      "expected IntegerType, IntegerType, got IntegerType, StringType")
+  },
+  finally = {
+    callJMethod(conf, "set", "spark.sql.execution.arrow.sparkr.enabled", arrowEnabled)
+  })
+})
+
 sparkR.session.stop()
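A minimal sketch of what the new test asserts, for reviewers who want to reproduce it interactively rather than through the test harness's `callJMethod` conf juggling. It assumes a local SparkR session; the data, function, and schema literals mirror the test above:

{% highlight r %}
library(SparkR)
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.sparkr.enabled = "true"))

df <- createDataFrame(list(list(a = 1L, b = "a")))

# The function returns each group unchanged, i.e. columns (int, string),
# but the declared output schema claims (int, int); with Arrow enabled this
# now fails with the new "Invalid schema from gapply()" message.
count(gapply(df, "a", function(key, group) { group }, structType("a int, b int")))

# Declaring the schema that matches the returned data.frame succeeds.
count(gapply(df, "a", function(key, group) { group }, structType("a int, b string")))
{% endhighlight %}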
diff --git a/docs/sparkr.md b/docs/sparkr.md
index fa1bb1b851815..05310f89f278d 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -681,12 +681,12 @@ The current supported minimum version is 1.0.0; however, this might change betwe
 Arrow optimization is available when converting a Spark DataFrame to an R DataFrame using the call `collect(spark_df)`,
 when creating a Spark DataFrame from an R DataFrame with `createDataFrame(r_df)`, when applying an R native function
 to each partition via `dapply(...)` and when applying an R native function to grouped data via `gapply(...)`.
-To use Arrow when executing these calls, users need to first set the Spark configuration ‘spark.sql.execution.arrow.sparkr.enabled’
-to ‘true’. This is disabled by default.
+To use Arrow when executing these calls, users need to set the Spark configuration ‘spark.sql.execution.arrow.sparkr.enabled’
+to ‘true’ first. This is disabled by default.
 
-In addition, optimizations enabled by ‘spark.sql.execution.arrow.sparkr.enabled’ could fallback automatically to non-Arrow optimization
-implementation if an error occurs before the actual computation within Spark during converting a Spark DataFrame to/from an R
-DataFrame.
+Whether the optimization is enabled or not, SparkR produces the same results. In addition, the conversion
+between a Spark DataFrame and an R DataFrame falls back automatically to the non-Arrow implementation
+when the optimization fails for any reason before the actual computation.
 
 <div data-lang="r" markdown="1">
 {% highlight r %}
@@ -713,9 +713,9 @@ collect(gapply(spark_df,
 {% endhighlight %}
 </div>
 
-Using the above optimizations with Arrow will produce the same results as when Arrow is not enabled. Note that even with Arrow,
-`collect(spark_df)` results in the collection of all records in the DataFrame to the driver program and should be done on a
-small subset of the data.
+Note that even with Arrow, `collect(spark_df)` results in the collection of all records in the DataFrame to
+the driver program and should be done on a small subset of the data. In addition, the output schema specified
+in `gapply(...)` and `dapply(...)` should match the R DataFrame returned by the given function.
 
 ## Supported SQL Types
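For context on the doc wording above, this is how a user enables the optimization. The snippet mirrors the session-setup style already used in sparkr.md; `master = "local[*]"` is an illustrative choice:

{% highlight r %}
library(SparkR)

# Start a session with the SparkR Arrow optimization enabled (off by default).
sparkR.session(master = "local[*]",
               sparkConfig = list(spark.sql.execution.arrow.sparkr.enabled = "true"))

spark_df <- createDataFrame(mtcars)  # R data.frame -> Spark DataFrame, Arrow-accelerated
rdf <- collect(spark_df)             # Spark DataFrame -> R data.frame, Arrow-accelerated
{% endhighlight %}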
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 4b2d4195ee906..c08db132c946f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -567,7 +567,14 @@ case class FlatMapGroupsInRWithArrowExec(
       // binary in a batch due to the limitation of R API. See also ARROW-4512.
       val columnarBatchIter = runner.compute(groupedByRKey, -1)
       val outputProject = UnsafeProjection.create(output, output)
-      columnarBatchIter.flatMap(_.rowIterator().asScala).map(outputProject)
+      val outputTypes = StructType.fromAttributes(output).map(_.dataType)
+
+      columnarBatchIter.flatMap { batch =>
+        val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
+        assert(outputTypes == actualDataTypes, "Invalid schema from gapply(): " +
+          s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+        batch.rowIterator().asScala
+      }.map(outputProject)
     }
   }
 }
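The doc change also notes the same schema requirement for `dapply(...)`, although the new assert above only covers the gapply path (`FlatMapGroupsInRWithArrowExec`). A hedged sketch of a matching dapply schema, assuming the Arrow-enabled session from the earlier snippet:

{% highlight r %}
df <- createDataFrame(data.frame(a = 1:3))

# The function appends a double column, so the declared output schema
# must be (int, double) to match the returned R data.frame.
collect(dapply(df,
               function(rdf) { cbind(rdf, b = rdf$a * 1.5) },
               structType("a int, b double")))
{% endhighlight %}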