From c4ae79e420a99009c400661ecf0a6d469e7d1552 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 11 Oct 2022 11:51:49 +0200 Subject: [PATCH 01/13] Scala implementation for DataFrameWriter --- .../protobuf/spark/connect/commands.proto | 29 ++++ .../command/SparkConnectCommandPlanner.scala | 74 ++++++++- .../spark/sql/connect/dsl/package.scala | 56 ++++++- .../connect/planner/SparkConnectPlanner.scala | 24 +-- .../SparkConnectCommandPlannerSuite.scala | 155 ++++++++++++++++++ .../planner/SparkConnectPlannerSuite.scala | 32 +++- .../planner/SparkConnectProtoSuite.scala | 14 -- 7 files changed, 342 insertions(+), 42 deletions(-) create mode 100644 connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto index 0a83e4543f5ec..b653f7df23499 100644 --- a/connector/connect/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto @@ -17,6 +17,8 @@ syntax = 'proto3'; +import "spark/connect/expressions.proto"; +import "spark/connect/relations.proto"; import "spark/connect/types.proto"; package spark.connect; @@ -29,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto"; message Command { oneof command_type { CreateScalarFunction create_function = 1; + WriteOperation write_operation = 2; } } @@ -62,3 +65,29 @@ message CreateScalarFunction { FUNCTION_LANGUAGE_SCALA = 3; } } + +// As writes are not directly handled during analysis and planning, they are modeled as commands. +message WriteOperation { + Relation input = 1; + string format = 2; + + oneof save_type { + string path = 3; + string table_name = 4; + } + string mode = 5; + repeated string sortColumnNames = 6; + repeated string partitionByColumns = 7; + BucketBy bucketBy = 8; + repeated WriteOptions options = 9; + + message BucketBy { + repeated string columns = 1; + int32 bucketCount = 2; + } + + message WriteOptions { + string key = 1; + string value = 2; + } +} \ No newline at end of file diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index ebc5cfe5b55b7..40bbbaecf6792 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -24,10 +24,16 @@ import com.google.common.collect.{Lists, Maps} import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto -import org.apache.spark.sql.SparkSession +import org.apache.spark.connect.proto.WriteOperation +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types.StringType +final case class InvalidCommandInput( + private val message: String = "", + private val cause: Throwable = None.orNull) + extends Exception(message, cause) @Unstable @Since("3.4.0") @@ -40,6 +46,8 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) command.getCommandTypeCase match { case proto.Command.CommandTypeCase.CREATE_FUNCTION => 
        handleCreateScalarFunction(command.getCreateFunction)
+      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
+        handleWriteOperation(command.getWriteOperation)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
@@ -47,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
   /**
    * This is a helper function that registers a new Python function in the SparkSession.
    *
-   * Right now this function is very rudimentary and bare-bones just to showcase how it
-   * is possible to remotely serialize a Python function and execute it on the Spark cluster.
-   * If the Python version on the client and server diverge, the execution of the function that
-   * is serialized will most likely fail.
+   * Right now this function is very rudimentary and bare-bones just to showcase how it is
+   * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+   * Python version on the client and server diverge, the execution of the function that is
+   * serialized will most likely fail.
    *
    * @param cf
    */
@@ -74,4 +82,60 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
     session.udf.registerPython(cf.getPartsList.asScala.head, udf)
   }
 
+  /**
+   * Transforms the write operation and executes it.
+   *
+   * The input write operation contains a reference to the input plan and transforms it to the
+   * corresponding logical plan. Afterwards, it creates the DataFrameWriter and translates the
+   * parameters of the WriteOperation into the corresponding method calls.
+   *
+   * @param writeOperation
+   */
+  def handleWriteOperation(writeOperation: WriteOperation): Unit = {
+    // Transform the input plan into the logical plan.
+    val planner = new SparkConnectPlanner(writeOperation.getInput, session)
+    val plan = planner.transform()
+    // And create a Dataset from the plan.
+ val dataset = Dataset.ofRows(session, logicalPlan = plan) + + val w = dataset.write + if (writeOperation.getOptionsCount > 0) { + writeOperation.getOptionsList.asScala.foreach(x => w.option(x.getKey, x.getValue)) + } + + if (writeOperation.getSortColumnNamesCount > 0) { + val names = writeOperation.getSortColumnNamesList.asScala + w.sortBy(names.head, names.tail.toSeq: _*) + } + + if (writeOperation.hasBucketBy) { + val op = writeOperation.getBucketBy + val cols = op.getColumnsList.asScala + if (op.getBucketCount <= 0) { + throw InvalidCommandInput( + s"BucketBy must specify a bucket count > 0, received ${op.getBucketCount} instead.") + } + w.bucketBy(op.getBucketCount, cols.head, cols.tail.toSeq: _*) + } + + if (writeOperation.getPartitionByColumnsCount > 0) { + val names = writeOperation.getPartitionByColumnsList.asScala + w.partitionBy(names.toSeq: _*) + } + + if (writeOperation.getFormat != null) { + w.format(writeOperation.getFormat) + } + + writeOperation.getSaveTypeCase match { + case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath) + case proto.WriteOperation.SaveTypeCase.TABLE_NAME => + w.saveAsTable(writeOperation.getTableName) + case _ => + throw new UnsupportedOperationException( + s"WriteOperation:SaveTypeCase not supported " + + "${writeOperation.getSaveTypeCase.getNumber}") + } + } + } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 80d6e77c9fc45..16bc0def690a7 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -33,9 +33,11 @@ package object dsl { val identifier = CatalystSqlParser.parseMultipartIdentifier(s) def protoAttr: proto.Expression = - proto.Expression.newBuilder() + proto.Expression + .newBuilder() .setUnresolvedAttribute( - proto.Expression.UnresolvedAttribute.newBuilder() + proto.Expression.UnresolvedAttribute + .newBuilder() .addAllParts(identifier.asJava) .build()) .build() @@ -47,15 +49,53 @@ package object dsl { } } + object commands { // scalastyle:ignore + implicit class DslCommands(val logicalPlan: proto.Relation) { + def write( + format: Option[String] = None, + path: Option[String] = None, + tableName: Option[String] = None, + mode: Option[String] = None, + sortByColumns: Seq[String] = Seq.empty, + partitionByCols: Seq[String] = Seq.empty, + bucketByCols: Seq[String] = Seq.empty, + numBuckets: Option[Int] = None): proto.Command = { + val writeOp = proto.WriteOperation.newBuilder() + format.foreach(writeOp.setFormat(_)) + mode.foreach(writeOp.setMode(_)) + + if (tableName.nonEmpty) { + tableName.foreach(writeOp.setTableName(_)) + } else { + path.foreach(writeOp.setPath(_)) + } + sortByColumns.foreach(writeOp.addSortColumnNames(_)) + partitionByCols.foreach(writeOp.addPartitionByColumns(_)) + + if (numBuckets.nonEmpty && bucketByCols.nonEmpty) { + val op = proto.WriteOperation.BucketBy.newBuilder() + numBuckets.foreach(op.setBucketCount(_)) + bucketByCols.foreach(op.addColumns(_)) + writeOp.setBucketBy(op.build()) + } + writeOp.setInput(logicalPlan) + proto.Command.newBuilder().setWriteOperation(writeOp.build()).build() + } + } + } + object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { def select(exprs: proto.Expression*): proto.Relation = { - proto.Relation.newBuilder().setProject( - proto.Project.newBuilder() - 
.setInput(logicalPlan) - .addAllExpressions(exprs.toIterable.asJava) - .build() - ).build() + proto.Relation + .newBuilder() + .setProject( + proto.Project + .newBuilder() + .setInput(logicalPlan) + .addAllExpressions(exprs.toIterable.asJava) + .build()) + .build() } def join( diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5ad95a6b516ab..2fe11cca60ead 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.types._ @@ -60,7 +61,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.LOCAL_RELATION => + transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -109,10 +111,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // TODO: support the target field for *. val projection = if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } val project = logical.Project(projectList = projection.toSeq, child = baseRel) if (common.nonEmpty && common.get.getAlias.nonEmpty) { logical.SubqueryAlias(identifier = common.get.getAlias, child = project) @@ -141,7 +143,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. * * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, - * Duration, Period. + * Duration, Period. * @param lit * @return * Expression @@ -167,9 +169,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // Days since UNIX epoch. 
case proto.Expression.Literal.LiteralTypeCase.DATE => expressions.Literal(lit.getDate, DateType) - case _ => throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") + case _ => + throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") } } @@ -188,7 +191,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * * TODO(SPARK-40546) We need to homogenize the function names for binary operators. * - * @param fun Proto representation of the function call. + * @param fun + * Proto representation of the function call. * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala new file mode 100644 index 0000000000000..a1614e4f18c98 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import java.io.File +import java.nio.file.{Files, Paths} +import java.util.UUID + +import org.apache.commons.io.FileUtils +import org.apache.spark.connect.proto +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner} +import org.apache.spark.sql.connect.dsl.commands._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.{SparkClassNotFoundException, SparkFunSuite} + +class SparkConnectCommandPlannerSuite + extends SparkFunSuite + with SparkConnectPlanTest + with SharedSparkSession { + + lazy val localRelation = createLocalRelationProto(Seq($"id".int)) + + /** + * Returns a unique path name on every invocation. + * @return + */ + private def path(): String = s"/tmp/${UUID.randomUUID()}" + + /** + * Returns a unique valid table name indentifier on each invocation. + * @return + */ + private def table(): String = s"table${UUID.randomUUID().toString.replace("-", "")}" + + /** + * Helper method that takes a closure as an argument to handle cleanup of the resource created. + * @param thunk + */ + def withTable(thunk: String => Any): Unit = { + val name = table() + thunk(name) + spark.sql(s"drop table if exists ${name}") + } + + /** + * Helper method that takes a closure as an arugment and handles cleanup of the file system + * resource created. 
+ * @param thunk + */ + def withPath(thunk: String => Any): Unit = { + val name = path() + thunk(name) + FileUtils.deleteDirectory(new File(name)) + } + + def transform(cmd: proto.Command): Unit = { + new SparkConnectCommandPlanner(spark, cmd).process() + } + + test("Writes fails without path or table") { + assertThrows[UnsupportedOperationException] { + transform(localRelation.write()) + } + } + + test("Write fails with unknown table - AnalysisException") { + val cmd = readRel.write(tableName = Some("dest")) + assertThrows[AnalysisException] { + transform(cmd) + } + } + + test("Write with partitions") { + val name = table() + val cmd = localRelation.write( + tableName = Some(name), + format = Some("parquet"), + partitionByCols = Seq("noid")) + assertThrows[AnalysisException] { + transform(cmd) + } + } + + test("Write with invalid bucketBy configuration") { + val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0)) + assertThrows[InvalidCommandInput] { + transform(cmd) + } + } + + test("Write to Path") { + withPath { name => + val cmd = localRelation.write(format = Some("parquet"), path = Some(name)) + transform(cmd) + assert(Files.exists(Paths.get(name)), s"Output file must exist: ${name}") + } + } + + test("Write to Path with invalid input") { + // Wrong data source. + assertThrows[SparkClassNotFoundException]( + transform(localRelation.write(path = Some(path), format = Some("ThisAintNoFormat")))) + + // Default data source not found. + assertThrows[SparkClassNotFoundException](transform(localRelation.write(path = Some(path)))) + } + + test("Write with sortBy") { + // Sort by existing column. + transform( + localRelation.write( + tableName = Some(table), + format = Some("parquet"), + sortByColumns = Seq("id"), + bucketByCols = Seq("id"), + numBuckets = Some(10))) + + // Sort by non-existing column + assertThrows[AnalysisException]( + transform( + localRelation + .write( + tableName = Some(table), + format = Some("parquet"), + sortByColumns = Seq("noid"), + bucketByCols = Seq("id"), + numBuckets = Some(10)))) + } + + test("Write to Table") { + withTable { name => + val cmd = localRelation.write(format = Some("parquet"), tableName = Some(name)) + transform(cmd) + // Check that we can find and drop the table. + spark.sql(s"select count(*) from ${name}").collect() + } + } +} diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 10e17f121f0e5..0adfcfc0b12bc 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.UnresolvedStar import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** @@ -43,6 +44,23 @@ trait SparkConnectPlanTest { .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("table")) .build()) .build() + + /** + * Creates a local relation for testing purposes. The local relation is mapped to it's + * equivalent in Catalyst and can be easily used for planner testing. 
+ * + * @param attrs + * @return + */ + def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { + val localRelationBuilder = proto.LocalRelation.newBuilder() + // TODO: set data types for each local relation attribute one proto supports data type. + for (attr <- attrs) { + localRelationBuilder.addAttributes( + proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()) + } + proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() + } } /** @@ -88,16 +106,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Project") { - val readWithTable = proto.Read.newBuilder() + val readWithTable = proto.Read + .newBuilder() .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) .build() val project = - proto.Project.newBuilder() + proto.Project + .newBuilder() .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( - proto.Expression.newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build() - ).build() + proto.Expression + .newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()) + .build()) + .build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) assert(res.nodeName == "Project") diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 510b54cd25084..cac85f9d20ccc 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -16,11 +16,9 @@ */ package org.apache.spark.sql.connect.planner -import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -99,16 +97,4 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val sparkPlan = sparkTestRelation.groupBy($"id", $"name")() comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } - - private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { - val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute.newBuilder() - .setName(attr.name) - .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)) - ) - } - proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() - } } From 6d152e2b338f10a8dc841054362df149510daa93 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 11 Oct 2022 16:02:17 +0200 Subject: [PATCH 02/13] format --- .../sql/connect/planner/SparkConnectCommandPlannerSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala index a1614e4f18c98..8f0b3917286e1 100644 --- 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -22,13 +22,14 @@ import java.nio.file.{Files, Paths} import java.util.UUID import org.apache.commons.io.FileUtils + +import org.apache.spark.{SparkClassNotFoundException, SparkFunSuite} import org.apache.spark.connect.proto import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner} import org.apache.spark.sql.connect.dsl.commands._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.{SparkClassNotFoundException, SparkFunSuite} class SparkConnectCommandPlannerSuite extends SparkFunSuite From cdf41d64ccb28af1cf94e854f9640c4eca61ae8c Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 11 Oct 2022 18:13:44 +0200 Subject: [PATCH 03/13] fix --- .../sql/connect/planner/SparkConnectPlannerSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 0adfcfc0b12bc..67518f3bdb172 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -54,10 +54,12 @@ trait SparkConnectPlanTest { */ def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - // TODO: set data types for each local relation attribute one proto supports data type. for (attr <- attrs) { localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()) + proto.Expression.QualifiedAttribute + .newBuilder() + .setName(attr.name) + .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))) } proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } From 139873e670eac3adb81a133200033033dcbb5e75 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 12 Oct 2022 18:14:04 +0200 Subject: [PATCH 04/13] testing and stuff --- .../protobuf/spark/connect/commands.proto | 22 +++++++++++----- .../command/SparkConnectCommandPlanner.scala | 8 ++++-- .../spark/sql/connect/dsl/package.scala | 21 +++++++++++----- .../planner/DataTypeProtoConverter.scala | 25 +++++++++++++++++++ .../SparkConnectCommandPlannerSuite.scala | 18 ++++++++++++- .../planner/SparkConnectProtoSuite.scala | 22 ++++++++++------ 6 files changed, 94 insertions(+), 22 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto index b653f7df23499..c01237b57f0c3 100644 --- a/connector/connect/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto @@ -68,26 +68,36 @@ message CreateScalarFunction { // As writes are not directly handled during analysis and planning, they are modeled as commands. message WriteOperation { + // The output of the `input` relation will be persisted according to the options. Relation input = 1; + // Format value according to the Spark documentation. 
Examples are: text, parquet, delta.
   string format = 2;
-
+  // The destination of the write operation must be either a path or a table.
   oneof save_type {
     string path = 3;
     string table_name = 4;
   }
-  string mode = 5;
+  SaveMode mode = 5;
+  // List of columns to sort the output by.
   repeated string sortColumnNames = 6;
+  // List of columns for partitioning.
   repeated string partitionByColumns = 7;
+  // Optional bucketing specification. Bucketing must set the number of buckets and the columns
+  // to bucket by.
   BucketBy bucketBy = 8;
-  repeated WriteOptions options = 9;
+  // Optional list of configuration options.
+  map<string, string> options = 9;
 
   message BucketBy {
     repeated string columns = 1;
     int32 bucketCount = 2;
   }
 
-  message WriteOptions {
-    string key = 1;
-    string value = 2;
+  enum SaveMode {
+    SAVE_MODE_UNSPECIFIED = 0;
+    SAVE_MODE_APPEND = 1;
+    SAVE_MODE_OVERWRITE = 2;
+    SAVE_MODE_ERROR_IF_EXISTS = 3;
+    SAVE_MODE_IGNORE = 4;
   }
 }
\ No newline at end of file
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
index 40bbbaecf6792..21836a528adc5 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.WriteOperation
 import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types.StringType
 
@@ -99,8 +99,12 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
     val dataset = Dataset.ofRows(session, logicalPlan = plan)
 
     val w = dataset.write
+    if (writeOperation.getMode != null) {
+      w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
+    }
+
     if (writeOperation.getOptionsCount > 0) {
-      writeOperation.getOptionsList.asScala.foreach(x => w.option(x.getKey, x.getValue))
+      writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
     }
 
     if (writeOperation.getSortColumnNamesCount > 0) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 16bc0def690a7..91297c3ea065b 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -20,7 +20,9 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
 
 /**
  * A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -44,8 +46,10 @@ package object dsl { } implicit class DslExpression(val expr: proto.Expression) { - def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias( - proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build() + def as(alias: String): proto.Expression = proto.Expression + .newBuilder() + .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)) + .build() } } @@ -62,7 +66,11 @@ package object dsl { numBuckets: Option[Int] = None): proto.Command = { val writeOp = proto.WriteOperation.newBuilder() format.foreach(writeOp.setFormat(_)) - mode.foreach(writeOp.setMode(_)) + + mode + .map(SaveMode.valueOf(_)) + .map(DataTypeProtoConverter.toSaveModeProto(_)) + .foreach(writeOp.setMode(_)) if (tableName.nonEmpty) { tableName.foreach(writeOp.setTableName(_)) @@ -104,7 +112,8 @@ package object dsl { condition: Option[proto.Expression] = None): proto.Relation = { val relation = proto.Relation.newBuilder() val join = proto.Join.newBuilder() - join.setLeft(logicalPlan) + join + .setLeft(logicalPlan) .setRight(otherPlan) .setJoinType(joinType) if (condition.isDefined) { @@ -113,8 +122,8 @@ package object dsl { relation.setJoin(join).build() } - def groupBy( - groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { + def groupBy(groupingExprs: proto.Expression*)( + aggregateExprs: proto.Expression*): proto.Relation = { val agg = proto.Aggregate.newBuilder() agg.setInput(logicalPlan) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index b31855bfca993..a0a5ea82d14ce 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.types.{DataType, IntegerType, StringType} /** @@ -43,4 +44,28 @@ object DataTypeProtoConverter { throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } } + + def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = { + mode match { + case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append + case proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE => SaveMode.Ignore + case proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE => SaveMode.Overwrite + case proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS => SaveMode.ErrorIfExists + case _ => + throw new IllegalArgumentException( + s"Cannot convert from WriteOperaton.SaveMode to Spark SaveMode: ${mode.getNumber}") + } + } + + def toSaveModeProto(mode: SaveMode): proto.WriteOperation.SaveMode = { + mode match { + case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND + case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE + case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE + case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS + case _ => + throw new IllegalArgumentException( + s"Cannot convert from SaveMode to WriteOperation.SaveMode: ${mode.name()}") + } + } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala index 8f0b3917286e1..699537cb3467f 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -25,7 +25,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SparkClassNotFoundException, SparkFunSuite} import org.apache.spark.connect.proto -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner} import org.apache.spark.sql.connect.dsl.commands._ @@ -153,4 +153,20 @@ class SparkConnectCommandPlannerSuite spark.sql(s"select count(*) from ${name}").collect() } } + + test("SaveMode conversion tests") { + assertThrows[IllegalArgumentException]( + DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED)) + + val combinations = Seq( + (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND), + (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE), + (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE), + (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS)) + combinations.foreach { a => + assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2) + assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1) + } + } + } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index cac85f9d20ccc..6174bf1a0374f 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -19,7 +19,15 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ + FullOuter, + Inner, + LeftAnti, + LeftOuter, + LeftSemi, + PlanTest, + RightOuter +} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation /** @@ -64,12 +72,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None) comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false) for ((t, y) <- Seq( - (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), - (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), - (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter), - (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), - (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), - (JoinType.JOIN_TYPE_INNER, Inner))) { + (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), + (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), + (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter), + (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), + (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), + (JoinType.JOIN_TYPE_INNER, Inner))) { val connectPlan3 = { import org.apache.spark.sql.connect.dsl.plans._ transform(connectTestRelation.join(connectTestRelation2, t)) From 
36d320b52b042b9ad9981bc8480c2117912ac5a5 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 12 Oct 2022 18:39:00 +0200 Subject: [PATCH 05/13] fixing the test --- .../spark/sql/connect/command/SparkConnectCommandPlanner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index 21836a528adc5..1c36efdddf825 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -99,7 +99,7 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) val dataset = Dataset.ofRows(session, logicalPlan = plan) val w = dataset.write - if (writeOperation.getMode != null) { + if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode)) } From 7a3a16cc4f7ec1c72de5887ce5f7ca833070710f Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 12 Oct 2022 21:05:21 +0200 Subject: [PATCH 06/13] comments --- .../spark/sql/connect/command/SparkConnectCommandPlanner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index 1c36efdddf825..a76fccff1f302 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.StringType final case class InvalidCommandInput( private val message: String = "", - private val cause: Throwable = None.orNull) + private val cause: Throwable = null) extends Exception(message, cause) @Unstable From f681c72eb7dd5b2981b38ef40104757e9bf145fa Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Thu, 13 Oct 2022 22:28:24 +0200 Subject: [PATCH 07/13] comments --- .../spark/sql/connect/dsl/package.scala | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 4956b0b2c5595..b5688a482deae 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -47,23 +47,30 @@ package object dsl { } implicit class DslExpression(val expr: proto.Expression) { - def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias( - proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build() - - def < (other: proto.Expression): proto.Expression = - proto.Expression.newBuilder().setUnresolvedFunction( - proto.Expression.UnresolvedFunction.newBuilder() - .addParts("<") - .addArguments(expr) - .addArguments(other) - ).build() - - implicit def intToLiteral(i: Int): proto.Expression = - proto.Expression.newBuilder().setLiteral( - proto.Expression.Literal.newBuilder().setI32(i) - ).build() + def as(alias: String): proto.Expression = proto.Expression + 
.newBuilder() + .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)) + .build() + + def <(other: proto.Expression): proto.Expression = + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addParts("<") + .addArguments(expr) + .addArguments(other)) + .build() } + implicit def intToLiteral(i: Int): proto.Expression = + proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setI32(i)) + .build() + } + object commands { // scalastyle:ignore implicit class DslCommands(val logicalPlan: proto.Relation) { def write( @@ -118,13 +125,12 @@ package object dsl { } def where(condition: proto.Expression): proto.Relation = { - proto.Relation.newBuilder() - .setFilter( - proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition) - ).build() + proto.Relation + .newBuilder() + .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) + .build() } - def join( otherPlan: proto.Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER, From 5af6799d744f7be42aa2d82b785bbb8da3c74d40 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 08:18:47 +0200 Subject: [PATCH 08/13] review comments --- .../protobuf/spark/connect/commands.proto | 12 +++--- .../command/SparkConnectCommandPlanner.scala | 20 ++++----- .../spark/sql/connect/dsl/package.scala | 8 ++-- .../SparkConnectCommandPlannerSuite.scala | 41 +++++-------------- 4 files changed, 30 insertions(+), 51 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto index c01237b57f0c3..bc8bb47812242 100644 --- a/connector/connect/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto @@ -71,7 +71,7 @@ message WriteOperation { // The output of the `input` relation will be persisted according to the options. Relation input = 1; // Format value according to the Spark documentation. Examples are: text, parquet, delta. - string format = 2; + string source = 2; // The destination of the write operation must be either a path or a table. oneof save_type { string path = 3; @@ -79,18 +79,18 @@ message WriteOperation { } SaveMode mode = 5; // List of columns to sort the output by. - repeated string sortColumnNames = 6; + repeated string sort_column_names = 6; // List of columns for partitioning. - repeated string partitionByColumns = 7; + repeated string partitioning_columns = 7; // Optional bucketing specification. Bucketing must set the number of buckets and the columns // to bucket by. - BucketBy bucketBy = 8; + BucketBy bucket_by = 8; // Optional list of configuration options. 
  map<string, string> options = 9;
 
   message BucketBy {
-    repeated string columns = 1;
-    int32 bucketCount = 2;
+    repeated string bucket_column_names = 1;
+    int32 num_buckets = 2;
   }
 
   enum SaveMode {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
index a76fccff1f302..47d421a0359bf 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -114,21 +114,21 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
 
     if (writeOperation.hasBucketBy) {
       val op = writeOperation.getBucketBy
-      val cols = op.getColumnsList.asScala
-      if (op.getBucketCount <= 0) {
+      val cols = op.getBucketColumnNamesList.asScala
+      if (op.getNumBuckets <= 0) {
         throw InvalidCommandInput(
-          s"BucketBy must specify a bucket count > 0, received ${op.getBucketCount} instead.")
+          s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
       }
-      w.bucketBy(op.getBucketCount, cols.head, cols.tail.toSeq: _*)
+      w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
     }
 
-    if (writeOperation.getPartitionByColumnsCount > 0) {
-      val names = writeOperation.getPartitionByColumnsList.asScala
+    if (writeOperation.getPartitioningColumnsCount > 0) {
+      val names = writeOperation.getPartitioningColumnsList.asScala
       w.partitionBy(names.toSeq: _*)
     }
 
-    if (writeOperation.getFormat != null) {
-      w.format(writeOperation.getFormat)
+    if (writeOperation.getSource != null) {
+      w.format(writeOperation.getSource)
     }
 
     writeOperation.getSaveTypeCase match {
@@ -137,8 +137,8 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
         w.saveAsTable(writeOperation.getTableName)
       case _ =>
         throw new UnsupportedOperationException(
-          s"WriteOperation:SaveTypeCase not supported "
-            + "${writeOperation.getSaveTypeCase.getNumber}")
+          "WriteOperation:SaveTypeCase not supported " +
+            s"${writeOperation.getSaveTypeCase.getNumber}")
     }
   }
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index b5688a482deae..9fc543d07ee36 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -83,7 +83,7 @@ package object dsl {
           bucketByCols: Seq[String] = Seq.empty,
           numBuckets: Option[Int] = None): proto.Command = {
         val writeOp = proto.WriteOperation.newBuilder()
-        format.foreach(writeOp.setFormat(_))
+        format.foreach(writeOp.setSource(_))
 
         mode
           .map(SaveMode.valueOf(_))
@@ -96,12 +96,12 @@ package object dsl {
           path.foreach(writeOp.setPath(_))
         }
         sortByColumns.foreach(writeOp.addSortColumnNames(_))
-        partitionByCols.foreach(writeOp.addPartitionByColumns(_))
+        partitionByCols.foreach(writeOp.addPartitioningColumns(_))
 
         if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
           val op = proto.WriteOperation.BucketBy.newBuilder()
-          numBuckets.foreach(op.setBucketCount(_))
-          bucketByCols.foreach(op.addColumns(_))
+          numBuckets.foreach(op.setNumBuckets(_))
+          bucketByCols.foreach(op.addBucketColumnNames(_))
           writeOp.setBucketBy(op.build())
         }
         writeOp.setInput(logicalPlan)
diff --git 
a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala index 699537cb3467f..36efd31f5590b 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -17,22 +17,19 @@ package org.apache.spark.sql.connect.planner -import java.io.File import java.nio.file.{Files, Paths} import java.util.UUID -import org.apache.commons.io.FileUtils - -import org.apache.spark.{SparkClassNotFoundException, SparkFunSuite} +import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner} import org.apache.spark.sql.connect.dsl.commands._ -import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} class SparkConnectCommandPlannerSuite - extends SparkFunSuite + extends SQLTestUtils with SparkConnectPlanTest with SharedSparkSession { @@ -50,27 +47,6 @@ class SparkConnectCommandPlannerSuite */ private def table(): String = s"table${UUID.randomUUID().toString.replace("-", "")}" - /** - * Helper method that takes a closure as an argument to handle cleanup of the resource created. - * @param thunk - */ - def withTable(thunk: String => Any): Unit = { - val name = table() - thunk(name) - spark.sql(s"drop table if exists ${name}") - } - - /** - * Helper method that takes a closure as an arugment and handles cleanup of the file system - * resource created. - * @param thunk - */ - def withPath(thunk: String => Any): Unit = { - val name = path() - thunk(name) - FileUtils.deleteDirectory(new File(name)) - } - def transform(cmd: proto.Command): Unit = { new SparkConnectCommandPlanner(spark, cmd).process() } @@ -107,10 +83,13 @@ class SparkConnectCommandPlannerSuite } test("Write to Path") { - withPath { name => - val cmd = localRelation.write(format = Some("parquet"), path = Some(name)) + withTempDir { f => + val cmd = localRelation.write( + format = Some("parquet"), + path = Some(f.getPath), + mode = Some("Overwrite")) transform(cmd) - assert(Files.exists(Paths.get(name)), s"Output file must exist: ${name}") + assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}") } } @@ -146,7 +125,7 @@ class SparkConnectCommandPlannerSuite } test("Write to Table") { - withTable { name => + withTable(table()) { name: String => val cmd = localRelation.write(format = Some("parquet"), tableName = Some(name)) transform(cmd) // Check that we can find and drop the table. 
From dc200c8690c0e7145a7c1f63194fbfdd80214fa5 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 08:22:15 +0200 Subject: [PATCH 09/13] formatting --- .../sql/connect/planner/SparkConnectProtoSuite.scala | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 6eac6383273ed..418965e78ebef 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -19,15 +19,7 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.{ - FullOuter, - Inner, - LeftAnti, - LeftOuter, - LeftSemi, - PlanTest, - RightOuter -} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation /** From 1d70250d8b7dd6793bcc1b0ba9965e68a9016ced Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 08:33:25 +0200 Subject: [PATCH 10/13] formatting --- .../command/SparkConnectCommandPlanner.scala | 8 +++---- .../connect/planner/SparkConnectPlanner.scala | 24 ++++++++----------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index 47d421a0359bf..ae606a6a72edd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -55,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) /** * This is a helper function that registers a new Python function in the SparkSession. * - * Right now this function is very rudimentary and bare-bones just to showcase how it is - * possible to remotely serialize a Python function and execute it on the Spark cluster. If the - * Python version on the client and server diverge, the execution of the function that is - * serialized will most likely fail. + * Right now this function is very rudimentary and bare-bones just to showcase how it + * is possible to remotely serialize a Python function and execute it on the Spark cluster. + * If the Python version on the client and server diverge, the execution of the function that + * is serialized will most likely fail. 
* * @param cf */ diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 2fe11cca60ead..5ad95a6b516ab 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.types._ @@ -61,8 +60,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.LOCAL_RELATION => - transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -111,10 +109,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // TODO: support the target field for *. val projection = if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + Seq(UnresolvedStar(Option.empty)) + } else { + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) + } val project = logical.Project(projectList = projection.toSeq, child = baseRel) if (common.nonEmpty && common.get.getAlias.nonEmpty) { logical.SubqueryAlias(identifier = common.get.getAlias, child = project) @@ -143,7 +141,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * Transforms the protocol buffers literals into the appropriate Catalyst literal expression. * * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp, - * Duration, Period. + * Duration, Period. * @param lit * @return * Expression @@ -169,10 +167,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { // Days since UNIX epoch. case proto.Expression.Literal.LiteralTypeCase.DATE => expressions.Literal(lit.getDate, DateType) - case _ => - throw InvalidPlanInput( - s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + - s"(${lit.getLiteralTypeCase.name})") + case _ => throw InvalidPlanInput( + s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + + s"(${lit.getLiteralTypeCase.name})") } } @@ -191,8 +188,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { * * TODO(SPARK-40546) We need to homogenize the function names for binary operators. * - * @param fun - * Proto representation of the function call. 
+ * @param fun Proto representation of the function call. * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { From 85f81c03958892bf491e176fef46c4cbc6044d44 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 09:37:01 +0200 Subject: [PATCH 11/13] reverting format --- .../spark/sql/connect/dsl/package.scala | 19 ++++----- .../SparkConnectCommandPlannerSuite.scala | 42 +++++++------------ .../planner/SparkConnectPlannerSuite.scala | 13 +++--- 3 files changed, 29 insertions(+), 45 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 9fc543d07ee36..4a932afd64cc0 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -113,11 +113,8 @@ package object dsl { object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: proto.Relation) { def select(exprs: proto.Expression*): proto.Relation = { - proto.Relation - .newBuilder() - .setProject( - proto.Project - .newBuilder() + proto.Relation.newBuilder().setProject( + proto.Project.newBuilder() .setInput(logicalPlan) .addAllExpressions(exprs.toIterable.asJava) .build()) @@ -125,10 +122,10 @@ package object dsl { } def where(condition: proto.Expression): proto.Relation = { - proto.Relation - .newBuilder() - .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)) - .build() + proto.Relation.newBuilder() + .setFilter( + proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition) + ).build() } def join( @@ -147,8 +144,8 @@ package object dsl { relation.setJoin(join).build() } - def groupBy(groupingExprs: proto.Expression*)( - aggregateExprs: proto.Expression*): proto.Relation = { + def groupBy( + groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { val agg = proto.Aggregate.newBuilder() agg.setInput(logicalPlan) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala index 36efd31f5590b..ea2ed6f6a1701 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.planner import java.nio.file.{Files, Paths} -import java.util.UUID import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto @@ -35,18 +34,6 @@ class SparkConnectCommandPlannerSuite lazy val localRelation = createLocalRelationProto(Seq($"id".int)) - /** - * Returns a unique path name on every invocation. - * @return - */ - private def path(): String = s"/tmp/${UUID.randomUUID()}" - - /** - * Returns a unique valid table name indentifier on each invocation. 
- * @return - */ - private def table(): String = s"table${UUID.randomUUID().toString.replace("-", "")}" - def transform(cmd: proto.Command): Unit = { new SparkConnectCommandPlanner(spark, cmd).process() } @@ -65,9 +52,8 @@ class SparkConnectCommandPlannerSuite } test("Write with partitions") { - val name = table() val cmd = localRelation.write( - tableName = Some(name), + tableName = Some("testtable"), format = Some("parquet"), partitionByCols = Seq("noid")) assertThrows[AnalysisException] { @@ -96,28 +82,32 @@ class SparkConnectCommandPlannerSuite test("Write to Path with invalid input") { // Wrong data source. assertThrows[SparkClassNotFoundException]( - transform(localRelation.write(path = Some(path), format = Some("ThisAintNoFormat")))) + transform( + localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat")))) // Default data source not found. - assertThrows[SparkClassNotFoundException](transform(localRelation.write(path = Some(path)))) + assertThrows[SparkClassNotFoundException]( + transform(localRelation.write(path = Some("/tmp/tmppath")))) } test("Write with sortBy") { // Sort by existing column. - transform( - localRelation.write( - tableName = Some(table), - format = Some("parquet"), - sortByColumns = Seq("id"), - bucketByCols = Seq("id"), - numBuckets = Some(10))) + withTable("testtable") { table: String => + transform( + localRelation.write( + tableName = Some(table), + format = Some("parquet"), + sortByColumns = Seq("id"), + bucketByCols = Seq("id"), + numBuckets = Some(10))) + } // Sort by non-existing column assertThrows[AnalysisException]( transform( localRelation .write( - tableName = Some(table), + tableName = Some("testtable"), format = Some("parquet"), sortByColumns = Seq("noid"), bucketByCols = Seq("id"), @@ -125,7 +115,7 @@ class SparkConnectCommandPlannerSuite } test("Write to Table") { - withTable(table()) { name: String => + withTable("testtable") { name: String => val cmd = localRelation.write(format = Some("parquet"), tableName = Some(name)) transform(cmd) // Check that we can find and drop the table. 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 67518f3bdb172..a708b1fe62eb7 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -108,19 +108,16 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Project") { - val readWithTable = proto.Read - .newBuilder() + val readWithTable = proto.Read.newBuilder() .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build()) .build() val project = - proto.Project - .newBuilder() + proto.Project.newBuilder() .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( - proto.Expression - .newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build()) - .build()) + proto.Expression.newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build() + ).build()) .build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) From 0dc03ac694bcc17c73c3ef3d634072a318e337c0 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 09:38:21 +0200 Subject: [PATCH 12/13] reverting format --- .../spark/sql/connect/planner/SparkConnectPlannerSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index a708b1fe62eb7..ba6995bfc5a82 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -116,9 +116,8 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { .setInput(proto.Relation.newBuilder().setRead(readWithTable).build()) .addExpressions( proto.Expression.newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().build() - ).build()) - .build() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build() + ).build() val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res !== null) assert(res.nodeName == "Project") From a13f9e412feeb6c05bdc3e9e61d432f9d80965f5 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 14 Oct 2022 10:38:21 +0200 Subject: [PATCH 13/13] scala 2.13 fix --- .../planner/SparkConnectCommandPlannerSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala index ea2ed6f6a1701..e5ca670e4ddf5 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala @@ -92,10 +92,10 @@ class SparkConnectCommandPlannerSuite test("Write with sortBy") { // Sort by existing column. 
- withTable("testtable") { table: String => + withTable("testtable") { transform( localRelation.write( - tableName = Some(table), + tableName = Some("testtable"), format = Some("parquet"), sortByColumns = Seq("id"), bucketByCols = Seq("id"), @@ -115,11 +115,11 @@ class SparkConnectCommandPlannerSuite } test("Write to Table") { - withTable("testtable") { name: String => - val cmd = localRelation.write(format = Some("parquet"), tableName = Some(name)) + withTable("testtable") { + val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable")) transform(cmd) // Check that we can find and drop the table. - spark.sql(s"select count(*) from ${name}").collect() + spark.sql(s"select count(*) from testtable").collect() } }