diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto
index 0a83e4543f5ec..bc8bb47812242 100644
--- a/connector/connect/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -17,6 +17,8 @@
 
 syntax = 'proto3';
 
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
 import "spark/connect/types.proto";
 
 package spark.connect;
@@ -29,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
 message Command {
   oneof command_type {
     CreateScalarFunction create_function = 1;
+    WriteOperation write_operation = 2;
   }
 }
 
@@ -62,3 +65,39 @@ message CreateScalarFunction {
     FUNCTION_LANGUAGE_SCALA = 3;
   }
 }
+
+// As writes are not directly handled during analysis and planning, they are modeled as commands.
+message WriteOperation {
+  // The output of the `input` relation will be persisted according to the options.
+  Relation input = 1;
+  // Format value according to the Spark documentation. Examples are: text, parquet, delta.
+  string source = 2;
+  // The destination of the write operation must be either a path or a table.
+  oneof save_type {
+    string path = 3;
+    string table_name = 4;
+  }
+  SaveMode mode = 5;
+  // List of columns to sort the output by.
+  repeated string sort_column_names = 6;
+  // List of columns for partitioning.
+  repeated string partitioning_columns = 7;
+  // Optional bucketing specification. Bucketing must set the number of buckets and the columns
+  // to bucket by.
+  BucketBy bucket_by = 8;
+  // Optional list of configuration options.
+  map<string, string> options = 9;
+
+  message BucketBy {
+    repeated string bucket_column_names = 1;
+    int32 num_buckets = 2;
+  }
+
+  enum SaveMode {
+    SAVE_MODE_UNSPECIFIED = 0;
+    SAVE_MODE_APPEND = 1;
+    SAVE_MODE_OVERWRITE = 2;
+    SAVE_MODE_ERROR_IF_EXISTS = 3;
+    SAVE_MODE_IGNORE = 4;
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
index ebc5cfe5b55b7..ae606a6a72edd 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -24,10 +24,16 @@ import com.google.common.collect.{Lists, Maps}
 
 import org.apache.spark.annotation.{Since, Unstable}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.connect.proto.WriteOperation
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types.StringType
 
+final case class InvalidCommandInput(
+    private val message: String = "",
+    private val cause: Throwable = null)
+    extends Exception(message, cause)
+
 @Unstable
 @Since("3.4.0")
@@ -40,6 +46,8 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
         handleCreateScalarFunction(command.getCreateFunction)
+      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
+        handleWriteOperation(command.getWriteOperation)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
@@ -74,4 +82,64 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
     session.udf.registerPython(cf.getPartsList.asScala.head, udf)
   }
 
+  /**
+   * Transforms the write operation and executes it.
+   *
+   * The input write operation contains a reference to the input plan, which is transformed
+   * into the corresponding logical plan. Afterwards, the DataFrameWriter is created and the
+   * parameters of the WriteOperation are translated into the corresponding method calls.
+   *
+   * @param writeOperation the WriteOperation proto message to execute.
+   */
+  def handleWriteOperation(writeOperation: WriteOperation): Unit = {
+    // Transform the input plan into the logical plan.
+    val planner = new SparkConnectPlanner(writeOperation.getInput, session)
+    val plan = planner.transform()
+    // And create a Dataset from the plan.
+    val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+    val w = dataset.write
+    if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
+      w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
+    }
+
+    if (writeOperation.getOptionsCount > 0) {
+      writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
+    }
+
+    if (writeOperation.getSortColumnNamesCount > 0) {
+      val names = writeOperation.getSortColumnNamesList.asScala
+      w.sortBy(names.head, names.tail.toSeq: _*)
+    }
+
+    if (writeOperation.hasBucketBy) {
+      val op = writeOperation.getBucketBy
+      val cols = op.getBucketColumnNamesList.asScala
+      if (op.getNumBuckets <= 0) {
+        throw InvalidCommandInput(
+          s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
+      }
+      w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
+    }
+
+    if (writeOperation.getPartitioningColumnsCount > 0) {
+      val names = writeOperation.getPartitioningColumnsList.asScala
+      w.partitionBy(names.toSeq: _*)
+    }
+
+    if (writeOperation.getSource != null) {
+      w.format(writeOperation.getSource)
+    }
+
+    writeOperation.getSaveTypeCase match {
+      case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
+      case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
+        w.saveAsTable(writeOperation.getTableName)
+      case _ =>
+        throw new UnsupportedOperationException(
+          "WriteOperation:SaveTypeCase not supported " +
+            s"${writeOperation.getSaveTypeCase.getNumber}")
+    }
+  }
+
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 0db8ab9661074..4a932afd64cc0 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -21,7 +21,9 @@ import scala.language.implicitConversions
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
 
 /**
  * A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -34,59 +36,106 @@ package object dsl {
       val identifier = CatalystSqlParser.parseMultipartIdentifier(s)
 
       def protoAttr: proto.Expression =
-        proto.Expression.newBuilder()
+        proto.Expression
+          .newBuilder()
           .setUnresolvedAttribute(
-            proto.Expression.UnresolvedAttribute.newBuilder()
+            proto.Expression.UnresolvedAttribute
+              .newBuilder()
               .addAllParts(identifier.asJava)
               .build())
           .build()
     }
 
     implicit class DslExpression(val expr: proto.Expression) {
-      def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias(
-        proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()
-
-      def < (other: proto.Expression): proto.Expression =
-        proto.Expression.newBuilder().setUnresolvedFunction(
-          proto.Expression.UnresolvedFunction.newBuilder()
-            .addParts("<")
-            .addArguments(expr)
-            .addArguments(other)
-        ).build()
+      def as(alias: String): proto.Expression = proto.Expression
+        .newBuilder()
+        .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
+        .build()
+
+      def <(other: proto.Expression): proto.Expression =
+        proto.Expression
+          .newBuilder()
+          .setUnresolvedFunction(
+            proto.Expression.UnresolvedFunction
+              .newBuilder()
+              .addParts("<")
+              .addArguments(expr)
+              .addArguments(other))
+          .build()
     }
 
     implicit def intToLiteral(i: Int): proto.Expression =
-      proto.Expression.newBuilder().setLiteral(
-        proto.Expression.Literal.newBuilder().setI32(i)
-      ).build()
+      proto.Expression
+        .newBuilder()
+        .setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
+        .build()
+  }
+
+  object commands { // scalastyle:ignore
+    implicit class DslCommands(val logicalPlan: proto.Relation) {
+      def write(
+          format: Option[String] = None,
+          path: Option[String] = None,
+          tableName: Option[String] = None,
+          mode: Option[String] = None,
+          sortByColumns: Seq[String] = Seq.empty,
+          partitionByCols: Seq[String] = Seq.empty,
+          bucketByCols: Seq[String] = Seq.empty,
+          numBuckets: Option[Int] = None): proto.Command = {
+        val writeOp = proto.WriteOperation.newBuilder()
+        format.foreach(writeOp.setSource(_))
+
+        mode
+          .map(SaveMode.valueOf(_))
+          .map(DataTypeProtoConverter.toSaveModeProto(_))
+          .foreach(writeOp.setMode(_))
+
+        if (tableName.nonEmpty) {
+          tableName.foreach(writeOp.setTableName(_))
+        } else {
+          path.foreach(writeOp.setPath(_))
+        }
+        sortByColumns.foreach(writeOp.addSortColumnNames(_))
+        partitionByCols.foreach(writeOp.addPartitioningColumns(_))
+
+        if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
+          val op = proto.WriteOperation.BucketBy.newBuilder()
+          numBuckets.foreach(op.setNumBuckets(_))
+          bucketByCols.foreach(op.addBucketColumnNames(_))
+          writeOp.setBucketBy(op.build())
+        }
+        writeOp.setInput(logicalPlan)
+        proto.Command.newBuilder().setWriteOperation(writeOp.build()).build()
+      }
+    }
   }
 
   object plans { // scalastyle:ignore
     implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
       def select(exprs: proto.Expression*): proto.Relation = {
         proto.Relation.newBuilder().setProject(
-          proto.Project.newBuilder()
-            .setInput(logicalPlan)
-            .addAllExpressions(exprs.toIterable.asJava)
-            .build()
-        ).build()
+            proto.Project.newBuilder()
+              .setInput(logicalPlan)
+              .addAllExpressions(exprs.toIterable.asJava)
+              .build())
+          .build()
       }
 
       def where(condition: proto.Expression): proto.Relation = {
         proto.Relation.newBuilder()
           .setFilter(
             proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
-        ).build()
+          ).build()
       }
 
-
       def join(
           otherPlan: proto.Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
           condition: Option[proto.Expression] = None): proto.Relation = {
         val relation = proto.Relation.newBuilder()
         val join = proto.Join.newBuilder()
-        join.setLeft(logicalPlan)
+        join
+          .setLeft(logicalPlan)
           .setRight(otherPlan)
           .setJoinType(joinType)
         if (condition.isDefined) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index b31855bfca993..a0a5ea82d14ce 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connect.planner
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
 
 /**
@@ -43,4 +44,28 @@ object DataTypeProtoConverter {
         throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
     }
   }
+
+  def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
+    mode match {
+      case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
+      case proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE => SaveMode.Ignore
+      case proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE => SaveMode.Overwrite
+      case proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS => SaveMode.ErrorIfExists
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot convert from WriteOperation.SaveMode to Spark SaveMode: ${mode.getNumber}")
+    }
+  }
+
+  def toSaveModeProto(mode: SaveMode): proto.WriteOperation.SaveMode = {
+    mode match {
+      case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
+      case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
+      case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
+      case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot convert from SaveMode to WriteOperation.SaveMode: ${mode.name()}")
+    }
+  }
 }
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
new file mode 100644
index 0000000000000..e5ca670e4ddf5
--- /dev/null
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import java.nio.file.{Files, Paths}
+
+import org.apache.spark.SparkClassNotFoundException
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner}
+import org.apache.spark.sql.connect.dsl.commands._
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+
+class SparkConnectCommandPlannerSuite
+    extends SQLTestUtils
+    with SparkConnectPlanTest
+    with SharedSparkSession {
+
+  lazy val localRelation = createLocalRelationProto(Seq($"id".int))
+
+  def transform(cmd: proto.Command): Unit = {
+    new SparkConnectCommandPlanner(spark, cmd).process()
+  }
+
+  test("Write fails without path or table") {
+    assertThrows[UnsupportedOperationException] {
+      transform(localRelation.write())
+    }
+  }
+
+  test("Write fails with unknown table - AnalysisException") {
+    val cmd = readRel.write(tableName = Some("dest"))
+    assertThrows[AnalysisException] {
+      transform(cmd)
+    }
+  }
+
+  test("Write with partitions") {
+    val cmd = localRelation.write(
+      tableName = Some("testtable"),
+      format = Some("parquet"),
+      partitionByCols = Seq("noid"))
+    assertThrows[AnalysisException] {
+      transform(cmd)
+    }
+  }
+
+  test("Write with invalid bucketBy configuration") {
+    val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
+    assertThrows[InvalidCommandInput] {
+      transform(cmd)
+    }
+  }
+
+  test("Write to Path") {
+    withTempDir { f =>
+      val cmd = localRelation.write(
+        format = Some("parquet"),
+        path = Some(f.getPath),
+        mode = Some("Overwrite"))
+      transform(cmd)
+      assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}")
+    }
+  }
+
+  test("Write to Path with invalid input") {
+    // Wrong data source.
+    assertThrows[SparkClassNotFoundException](
+      transform(
+        localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat"))))
+
+    // Default data source not found.
+    assertThrows[SparkClassNotFoundException](
+      transform(localRelation.write(path = Some("/tmp/tmppath"))))
+  }
+
+  test("Write with sortBy") {
+    // Sort by existing column.
+    withTable("testtable") {
+      transform(
+        localRelation.write(
+          tableName = Some("testtable"),
+          format = Some("parquet"),
+          sortByColumns = Seq("id"),
+          bucketByCols = Seq("id"),
+          numBuckets = Some(10)))
+    }
+
+    // Sort by non-existing column.
+    assertThrows[AnalysisException](
+      transform(
+        localRelation
+          .write(
+            tableName = Some("testtable"),
+            format = Some("parquet"),
+            sortByColumns = Seq("noid"),
+            bucketByCols = Seq("id"),
+            numBuckets = Some(10))))
+  }
+
+  test("Write to Table") {
+    withTable("testtable") {
+      val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
+      transform(cmd)
+      // Check that we can find and drop the table.
+      spark.sql(s"select count(*) from testtable").collect()
+    }
+  }
+
+  test("SaveMode conversion tests") {
+    assertThrows[IllegalArgumentException](
+      DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED))
+
+    val combinations = Seq(
+      (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND),
+      (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE),
+      (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE),
+      (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS))
+    combinations.foreach { a =>
+      assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2)
+      assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1)
+    }
+  }
+
+}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 10e17f121f0e5..ba6995bfc5a82 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.UnresolvedStar
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 /**
@@ -43,6 +44,25 @@ trait SparkConnectPlanTest {
       .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("table"))
       .build())
     .build()
+
+  /**
+   * Creates a local relation for testing purposes. The local relation is mapped to its
+   * equivalent in Catalyst and can be easily used for planner testing.
+   *
+   * @param attrs the schema attributes of the local relation.
+   * @return the proto relation wrapping the local relation.
+   */
+  def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
+    val localRelationBuilder = proto.LocalRelation.newBuilder()
+    for (attr <- attrs) {
+      localRelationBuilder.addAttributes(
+        proto.Expression.QualifiedAttribute
+          .newBuilder()
+          .setName(attr.name)
+          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)))
+    }
+    proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
+  }
 }
 
 /**
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 351cc70852a18..418965e78ebef 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -16,11 +16,9 @@
  */
 package org.apache.spark.sql.connect.planner
 
-import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 
@@ -77,12 +75,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None)
     comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
     for ((t, y) <- Seq(
-      (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
-      (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
-      (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
-      (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
-      (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
-      (JoinType.JOIN_TYPE_INNER, Inner))) {
+        (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
+        (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
+        (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
+        (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
+        (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
+        (JoinType.JOIN_TYPE_INNER, Inner))) {
       val connectPlan3 = {
         import org.apache.spark.sql.connect.dsl.plans._
         transform(connectTestRelation.join(connectTestRelation2, t))
@@ -110,16 +108,4 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
-
-  private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
-    val localRelationBuilder = proto.LocalRelation.newBuilder()
-    for (attr <- attrs) {
-      localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder()
-          .setName(attr.name)
-          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
-      )
-    }
-    proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
-  }
 }
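A minimal usage sketch of how the new `commands` DSL and the command planner fit together, modeled on SparkConnectCommandPlannerSuite above; the `spark` session, the `localRelation` proto relation, and the output path are assumed here for illustration only.

  import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
  import org.apache.spark.sql.connect.dsl.commands._

  // Build a WriteOperation command from a proto relation via the new `write` DSL.
  val cmd = localRelation.write(
    format = Some("parquet"),
    path = Some("/tmp/spark-connect-write-example"), // hypothetical output path
    mode = Some("Overwrite"))

  // Execute it: the planner transforms the input relation into a logical plan,
  // configures a DataFrameWriter, and performs the save.
  new SparkConnectCommandPlanner(spark, cmd).process()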