39 changes: 39 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -17,6 +17,8 @@

syntax = 'proto3';

import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";

package spark.connect;
@@ -29,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
message Command {
oneof command_type {
CreateScalarFunction create_function = 1;
WriteOperation write_operation = 2;
}
}

@@ -62,3 +65,39 @@ message CreateScalarFunction {
FUNCTION_LANGUAGE_SCALA = 3;
}
}

// As writes are not directly handled during analysis and planning, they are modeled as commands.
message WriteOperation {
// The output of the `input` relation will be persisted according to the options.
Relation input = 1;
// Format value according to the Spark documentation. Examples are: text, parquet, delta.
string source = 2;
// The destination of the write operation must be either a path or a table.
Contributor: In the DF API, people can do df.write.format("jdbc").option("table", ...).save(), so the destination is neither a path nor a table. I think an optional table name is sufficient. If a table name is not given, the destination will be figured out from the write options (path is just one write option).

oneof save_type {
string path = 3;
string table_name = 4;
}
SaveMode mode = 5;
Contributor (@cloud-fan, Oct 17, 2022): We added DataFrameWriterV2 because we believe SaveMode is a bad design. It's confusing when we write to a table, as there are so many options: create if not exists, create or replace, replace if exists, append if exists, overwrite data if exists, etc.

Anyway, we need to support save mode in the proto definition to support the existing DF API. If we want to support DataFrameWriterV2 in the Spark Connect client, we should probably have a new proto definition without save mode.

// List of columns to sort the output by.
repeated string sort_column_names = 6;
Contributor: This should be part of the BucketBy.

// List of columns for partitioning.
repeated string partitioning_columns = 7;
// Optional bucketing specification. Bucketing must set the number of buckets and the columns
// to bucket by.
BucketBy bucket_by = 8;
// Optional list of configuration options.
map<string, string> options = 9;

message BucketBy {
repeated string bucket_column_names = 1;
int32 num_buckets = 2;
}

enum SaveMode {
SAVE_MODE_UNSPECIFIED = 0;
SAVE_MODE_APPEND = 1;
SAVE_MODE_OVERWRITE = 2;
SAVE_MODE_ERROR_IF_EXISTS = 3;
SAVE_MODE_IGNORE = 4;
}
}
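
For orientation, the lines below are a sketch (not part of this diff) of the DataFrameWriter call chain that a WriteOperation message is meant to model; `df` stands for an arbitrary client-side DataFrame, and the format, option, column, and table names are placeholders. The comments note the assumed field mapping.

// Illustrative only: how DataFrameWriter calls line up with WriteOperation fields.
df.write
  .format("parquet")               // source
  .mode("append")                  // mode = SAVE_MODE_APPEND
  .option("compression", "snappy") // options
  .partitionBy("date")             // partitioning_columns
  .bucketBy(4, "id")               // bucket_by.num_buckets and bucket_by.bucket_column_names
  .sortBy("id")                    // sort_column_names
  .saveAsTable("events")           // table_name (or .save(path) for the path variant of save_type)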
@@ -24,10 +24,16 @@ import com.google.common.collect.{Lists, Maps}
import org.apache.spark.annotation.{Since, Unstable}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.sql.SparkSession
import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.types.StringType

final case class InvalidCommandInput(
Contributor: Should we use IllegalArgumentException here? Or do you feel this needs its own specific exception?

Contributor (Author): I wanted to have a custom exception for when we rethrow.

Member: If this is a user-facing error, we should actually leverage the error framework we have. cc @gengliangwang @MaxGekk @itholic

Contributor (Author): I'm happy to fix this as a follow-up, does that make sense? The errors are reported back through gRPC. If you point me to the right base class, I can fix it then.

private val message: String = "",
private val cause: Throwable = null)
extends Exception(message, cause)

@Unstable
@Since("3.4.0")
@@ -40,6 +46,8 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
handleCreateScalarFunction(command.getCreateFunction)
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
handleWriteOperation(command.getWriteOperation)
case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
}
@@ -74,4 +82,64 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
session.udf.registerPython(cf.getPartsList.asScala.head, udf)
}

/**
* Transforms the write operation and executes it.
*
* The write operation contains a reference to the input relation, which is first transformed
* into the corresponding logical plan. Afterwards, a DataFrameWriter is created and the
* parameters of the WriteOperation are translated into the corresponding method calls.
*
* @param writeOperation the write operation to execute.
*/
def handleWriteOperation(writeOperation: WriteOperation): Unit = {
Contributor: It is a bit weird to have this in the SparkPlanner node, but I guess this is a consequence of the builder() API we have in the DataFrameWriter. @cloud-fan AFAIK you have been working on making writes more declarative (i.e. planned writes). Do you see a way to improve this?

Contributor: This is more than planned writes. We need to create a logical plan for DF writes, instead of putting implementation code in the DF write APIs.

// Transform the input plan into the logical plan.
val planner = new SparkConnectPlanner(writeOperation.getInput, session)
val plan = planner.transform()
// And create a Dataset from the plan.
val dataset = Dataset.ofRows(session, logicalPlan = plan)

val w = dataset.write
if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
}

if (writeOperation.getOptionsCount > 0) {
writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
}

if (writeOperation.getSortColumnNamesCount > 0) {
val names = writeOperation.getSortColumnNamesList.asScala
w.sortBy(names.head, names.tail.toSeq: _*)
}

if (writeOperation.hasBucketBy) {
val op = writeOperation.getBucketBy
val cols = op.getBucketColumnNamesList.asScala
if (op.getNumBuckets <= 0) {
throw InvalidCommandInput(
s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
}
w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
}

if (writeOperation.getPartitioningColumnsCount > 0) {
val names = writeOperation.getPartitioningColumnsList.asScala
w.partitionBy(names.toSeq: _*)
}

if (writeOperation.getSource.nonEmpty) {
w.format(writeOperation.getSource)
}

writeOperation.getSaveTypeCase match {
case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
w.saveAsTable(writeOperation.getTableName)
case _ =>
throw new UnsupportedOperationException(
"WriteOperation:SaveTypeCase not supported "
+ s"${writeOperation.getSaveTypeCase.getNumber}")
}
}

}
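
To make the validation above concrete, here is a sketch (not part of this diff) of an input that handleWriteOperation rejects; `someRelation` is a placeholder for any proto.Relation built elsewhere.

// A BucketBy with num_buckets = 0 fails the check in handleWriteOperation
// and surfaces as InvalidCommandInput before any writer method is called.
val badWrite = proto.WriteOperation
  .newBuilder()
  .setInput(someRelation) // placeholder proto.Relation
  .setBucketBy(
    proto.WriteOperation.BucketBy
      .newBuilder()
      .addBucketColumnNames("id")
      .setNumBuckets(0)
      .build())
  .build()
// handleWriteOperation(badWrite) throws InvalidCommandInput(
//   "BucketBy must specify a bucket count > 0, received 0 instead.")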
@@ -21,7 +21,9 @@ import scala.language.implicitConversions

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter

/**
* A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -34,59 +36,106 @@ package object dsl {
val identifier = CatalystSqlParser.parseMultipartIdentifier(s)

def protoAttr: proto.Expression =
proto.Expression.newBuilder()
proto.Expression
.newBuilder()
.setUnresolvedAttribute(
proto.Expression.UnresolvedAttribute.newBuilder()
proto.Expression.UnresolvedAttribute
.newBuilder()
.addAllParts(identifier.asJava)
.build())
.build()
}

implicit class DslExpression(val expr: proto.Expression) {
def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias(
proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()

def < (other: proto.Expression): proto.Expression =
proto.Expression.newBuilder().setUnresolvedFunction(
proto.Expression.UnresolvedFunction.newBuilder()
.addParts("<")
.addArguments(expr)
.addArguments(other)
).build()
def as(alias: String): proto.Expression = proto.Expression
.newBuilder()
.setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
.build()

def <(other: proto.Expression): proto.Expression =
proto.Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
.newBuilder()
.addParts("<")
.addArguments(expr)
.addArguments(other))
.build()
}

implicit def intToLiteral(i: Int): proto.Expression =
proto.Expression.newBuilder().setLiteral(
proto.Expression.Literal.newBuilder().setI32(i)
).build()
proto.Expression
.newBuilder()
.setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
.build()
}

object commands { // scalastyle:ignore
implicit class DslCommands(val logicalPlan: proto.Relation) {
def write(
format: Option[String] = None,
path: Option[String] = None,
tableName: Option[String] = None,
mode: Option[String] = None,
sortByColumns: Seq[String] = Seq.empty,
partitionByCols: Seq[String] = Seq.empty,
bucketByCols: Seq[String] = Seq.empty,
numBuckets: Option[Int] = None): proto.Command = {
val writeOp = proto.WriteOperation.newBuilder()
format.foreach(writeOp.setSource(_))

mode
.map(SaveMode.valueOf(_))
.map(DataTypeProtoConverter.toSaveModeProto(_))
.foreach(writeOp.setMode(_))

if (tableName.nonEmpty) {
tableName.foreach(writeOp.setTableName(_))
} else {
path.foreach(writeOp.setPath(_))
}
sortByColumns.foreach(writeOp.addSortColumnNames(_))
partitionByCols.foreach(writeOp.addPartitioningColumns(_))

if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
val op = proto.WriteOperation.BucketBy.newBuilder()
numBuckets.foreach(op.setNumBuckets(_))
bucketByCols.foreach(op.addBucketColumnNames(_))
writeOp.setBucketBy(op.build())
}
writeOp.setInput(logicalPlan)
proto.Command.newBuilder().setWriteOperation(writeOp.build()).build()
}
}
}
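
A sketch of how a test might exercise the new write() helper (assuming the implicits from this commands object are in scope and that `readRel` is some previously built proto.Relation; the format, path, and mode values are placeholders):

// Produces a proto.Command carrying a WriteOperation for `readRel`.
val command: proto.Command = readRel.write(
  format = Some("parquet"),
  path = Some("/tmp/connect-write-example"),
  mode = Some("Append"), // parsed with SaveMode.valueOf by the helper above
  partitionByCols = Seq("date"),
  bucketByCols = Seq("id"),
  numBuckets = Some(2))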

object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
def select(exprs: proto.Expression*): proto.Relation = {
proto.Relation.newBuilder().setProject(
proto.Project.newBuilder()
.setInput(logicalPlan)
.addAllExpressions(exprs.toIterable.asJava)
.build()
).build()
proto.Project.newBuilder()
.setInput(logicalPlan)
.addAllExpressions(exprs.toIterable.asJava)
.build())
.build()
}

def where(condition: proto.Expression): proto.Relation = {
proto.Relation.newBuilder()
.setFilter(
proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
).build()
).build()
}


def join(
otherPlan: proto.Relation,
joinType: JoinType = JoinType.JOIN_TYPE_INNER,
condition: Option[proto.Expression] = None): proto.Relation = {
val relation = proto.Relation.newBuilder()
val join = proto.Join.newBuilder()
join.setLeft(logicalPlan)
join
.setLeft(logicalPlan)
.setRight(otherPlan)
.setJoinType(joinType)
if (condition.isDefined) {
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

/**
@@ -43,4 +44,28 @@ object DataTypeProtoConverter {
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}

def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
mode match {
case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
case proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE => SaveMode.Ignore
case proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE => SaveMode.Overwrite
case proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS => SaveMode.ErrorIfExists
case _ =>
throw new IllegalArgumentException(
s"Cannot convert from WriteOperaton.SaveMode to Spark SaveMode: ${mode.getNumber}")
}
}

def toSaveModeProto(mode: SaveMode): proto.WriteOperation.SaveMode = {
mode match {
case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
case _ =>
throw new IllegalArgumentException(
s"Cannot convert from SaveMode to WriteOperation.SaveMode: ${mode.name()}")
}
}
}
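
As a minimal sanity check (illustrative, not part of this diff), the two helpers above should round-trip Spark's SaveMode values:

// SaveMode.Append maps to SAVE_MODE_APPEND and back.
val protoMode = DataTypeProtoConverter.toSaveModeProto(SaveMode.Append)
assert(protoMode == proto.WriteOperation.SaveMode.SAVE_MODE_APPEND)
assert(DataTypeProtoConverter.toSaveMode(protoMode) == SaveMode.Append)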