@@ -39,7 +39,8 @@ class SparkConnectPlugin extends SparkPlugin {
   /**
    * Return the plugin's driver-side component.
    *
-   * @return The driver-side component.
+   * @return
+   *   The driver-side component.
    */
   override def driverPlugin(): DriverPlugin = new DriverPlugin {
 
@@ -28,7 +28,6 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types.StringType
 
-
 @Unstable
 @Since("3.4.0")
 class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) {
@@ -47,10 +46,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
   /**
    * This is a helper function that registers a new Python function in the SparkSession.
    *
-   * Right now this function is very rudimentary and bare-bones just to showcase how it
-   * is possible to remotely serialize a Python function and execute it on the Spark cluster.
-   * If the Python version on the client and server diverge, the execution of the function that
-   * is serialized will most likely fail.
+   * Right now this function is very rudimentary and bare-bones just to showcase how it is
+   * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+   * Python version on the client and server diverge, the execution of the function that is
+   * serialized will most likely fail.
    *
    * @param cf
    */
@@ -34,59 +34,70 @@ package object dsl {
       val identifier = CatalystSqlParser.parseMultipartIdentifier(s)
 
       def protoAttr: proto.Expression =
-        proto.Expression.newBuilder()
+        proto.Expression
+          .newBuilder()
           .setUnresolvedAttribute(
-            proto.Expression.UnresolvedAttribute.newBuilder()
+            proto.Expression.UnresolvedAttribute
+              .newBuilder()
               .addAllParts(identifier.asJava)
               .build())
           .build()
     }
 
     implicit class DslExpression(val expr: proto.Expression) {
-      def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias(
-        proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build()
+      def as(alias: String): proto.Expression = proto.Expression
+        .newBuilder()
+        .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
+        .build()
 
-      def < (other: proto.Expression): proto.Expression =
-        proto.Expression.newBuilder().setUnresolvedFunction(
-          proto.Expression.UnresolvedFunction.newBuilder()
-            .addParts("<")
-            .addArguments(expr)
-            .addArguments(other)
-        ).build()
+      def <(other: proto.Expression): proto.Expression =
+        proto.Expression
+          .newBuilder()
+          .setUnresolvedFunction(
+            proto.Expression.UnresolvedFunction
+              .newBuilder()
+              .addParts("<")
+              .addArguments(expr)
+              .addArguments(other))
+          .build()
     }
 
     implicit def intToLiteral(i: Int): proto.Expression =
-      proto.Expression.newBuilder().setLiteral(
-        proto.Expression.Literal.newBuilder().setI32(i)
-      ).build()
+      proto.Expression
+        .newBuilder()
+        .setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
+        .build()
   }
 
   object plans { // scalastyle:ignore
     implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
       def select(exprs: proto.Expression*): proto.Relation = {
-        proto.Relation.newBuilder().setProject(
-          proto.Project.newBuilder()
-            .setInput(logicalPlan)
-            .addAllExpressions(exprs.toIterable.asJava)
-            .build()
-        ).build()
+        proto.Relation
+          .newBuilder()
+          .setProject(
+            proto.Project
+              .newBuilder()
+              .setInput(logicalPlan)
+              .addAllExpressions(exprs.toIterable.asJava)
+              .build())
+          .build()
       }
 
       def where(condition: proto.Expression): proto.Relation = {
-        proto.Relation.newBuilder()
-          .setFilter(
-            proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
-          ).build()
+        proto.Relation
+          .newBuilder()
+          .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
+          .build()
       }
 
-
       def join(
           otherPlan: proto.Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
           condition: Option[proto.Expression] = None): proto.Relation = {
         val relation = proto.Relation.newBuilder()
         val join = proto.Join.newBuilder()
-        join.setLeft(logicalPlan)
+        join
+          .setLeft(logicalPlan)
           .setRight(otherPlan)
           .setJoinType(joinType)
         if (condition.isDefined) {
@@ -95,8 +106,8 @@
         relation.setJoin(join).build()
       }
 
-      def groupBy(
-          groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+      def groupBy(groupingExprs: proto.Expression*)(
+          aggregateExprs: proto.Expression*): proto.Relation = {
         val agg = proto.Aggregate.newBuilder()
         agg.setInput(logicalPlan)
 
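The hunk above only reflows the DSL under scalafmt; its behavior is unchanged. For context, here is a minimal sketch of how the reformatted builders compose, assuming the proto classes live under org.apache.spark.connect.proto and this PR's dsl package object is on the classpath (the plans import appears verbatim in the test suite further down):

// Sketch only: assumes this PR's dsl package object is importable as shown.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._

object DslSketch extends App {
  // A Read relation over a named table, mirroring the planner suite below.
  val read: proto.Relation = proto.Relation
    .newBuilder()
    .setRead(
      proto.Read
        .newBuilder()
        .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("employees")))
    .build()

  // "id".protoAttr builds an UnresolvedAttribute proto; `<` wraps both operands
  // in an UnresolvedFunction named "<"; the Int 10 is lifted by intToLiteral.
  val plan: proto.Relation = read.select("id".protoAttr).where("id".protoAttr < 10)
  println(plan)
}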
@@ -60,7 +60,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
       case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
       case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
-      case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation)
+      case proto.Relation.RelTypeCase.LOCAL_RELATION =>
+        transformLocalRelation(rel.getLocalRelation)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -109,10 +110,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     // TODO: support the target field for *.
     val projection =
       if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) {
-      Seq(UnresolvedStar(Option.empty))
-    } else {
-      rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
-    }
+        Seq(UnresolvedStar(Option.empty))
+      } else {
+        rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
+      }
     val project = logical.Project(projectList = projection.toSeq, child = baseRel)
     if (common.nonEmpty && common.get.getAlias.nonEmpty) {
       logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
@@ -141,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    * Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
    *
    * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp,
-   * Duration, Period.
+   *   Duration, Period.
    * @param lit
    * @return
    *   Expression
@@ -167,9 +168,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         // Days since UNIX epoch.
         case proto.Expression.Literal.LiteralTypeCase.DATE =>
           expressions.Literal(lit.getDate, DateType)
-        case _ => throw InvalidPlanInput(
-          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
-            s"(${lit.getLiteralTypeCase.name})")
+        case _ =>
+          throw InvalidPlanInput(
+            s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+              s"(${lit.getLiteralTypeCase.name})")
       }
     }
 
@@ -188,7 +190,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    *
    * TODO(SPARK-40546) We need to homogenize the function names for binary operators.
    *
-   * @param fun Proto representation of the function call.
+   * @param fun
+   *   Proto representation of the function call.
    * @return
    */
   private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
@@ -278,11 +281,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
     val groupingExprs =
       rel.getGroupingExpressionsList.asScala
-      .map(transformExpression)
-      .map {
-        case x @ UnresolvedAttribute(_) => x
-        case x => UnresolvedAlias(x)
-      }
+        .map(transformExpression)
+        .map {
+          case x @ UnresolvedAttribute(_) => x
+          case x => UnresolvedAlias(x)
+        }
 
     logical.Aggregate(
       child = transformRelation(rel.getInput),
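The DATE arm above reads the literal's integer field as days since the UNIX epoch. A small sketch of that mapping with concrete values, restating the visible case arm (the setter name mirrors the getter used in the diff and is an assumption about the generated proto API):

import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types.DateType

object DateLiteralSketch extends App {
  // 18993 days after 1970-01-01 is 2022-01-01.
  val dateLit = proto.Expression.Literal.newBuilder().setDate(18993).build()

  // What the DATE case arm produces for this input: a Catalyst Literal whose
  // value is the raw day count and whose type is DateType.
  val catalystLit = expressions.Literal(dateLit.getDate, DateType)
  println(catalystLit)
}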
@@ -49,7 +49,8 @@ import org.apache.spark.sql.execution.ExtendedMode
 @Unstable
 @Since("3.4.0")
 class SparkConnectService(debug: Boolean)
-    extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging {
+    extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
+    with Logging {
 
   /**
    * This is the main entry method for Spark Connect and all calls to execute a plan.
@@ -183,7 +184,6 @@ object SparkConnectService {
 
   /**
    * Starts the GRPC Serivce.
-   *
    */
   def startGRPCService(): Unit = {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
@@ -212,4 +212,3 @@ object SparkConnectService {
     }
   }
 }
-
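startGRPCService() above reads spark.connect.grpc.debug.enabled with a default of true. As a hedged sketch, wiring the plugin into a session might look like the following; only the config key is taken from the diff, and the plugin's fully qualified class name is an assumption:

import org.apache.spark.sql.SparkSession

object ConnectServiceSketch extends App {
  val spark = SparkSession
    .builder()
    .master("local[*]")
    .appName("connect-service-sketch")
    // Hypothetical FQCN for the SparkConnectPlugin shown at the top of this diff.
    .config("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
    // Read by startGRPCService() via SparkEnv.get.conf.getBoolean, default true.
    .config("spark.connect.grpc.debug.enabled", "true")
    .getOrCreate()
}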
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.util.ArrowUtils
 
-
 @Unstable
 @Since("3.4.0")
 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
@@ -88,16 +88,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   }
 
   test("Simple Project") {
-    val readWithTable = proto.Read.newBuilder()
+    val readWithTable = proto.Read
+      .newBuilder()
       .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
       .build()
     val project =
-      proto.Project.newBuilder()
+      proto.Project
+        .newBuilder()
         .setInput(proto.Relation.newBuilder().setRead(readWithTable).build())
         .addExpressions(
-          proto.Expression.newBuilder()
-            .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build()
-        ).build()
+          proto.Expression
+            .newBuilder()
+            .setUnresolvedStar(UnresolvedStar.newBuilder().build())
+            .build())
+        .build()
     val res = transform(proto.Relation.newBuilder.setProject(project).build())
     assert(res !== null)
     assert(res.nodeName == "Project")
@@ -77,12 +77,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None)
     comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
     for ((t, y) <- Seq(
-      (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
-      (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
-      (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
-      (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
-      (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
-      (JoinType.JOIN_TYPE_INNER, Inner))) {
+        (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
+        (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
+        (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
+        (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
+        (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
+        (JoinType.JOIN_TYPE_INNER, Inner))) {
       val connectPlan3 = {
         import org.apache.spark.sql.connect.dsl.plans._
         transform(connectTestRelation.join(connectTestRelation2, t))
@@ -115,10 +115,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
     for (attr <- attrs) {
       localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder()
+        proto.Expression.QualifiedAttribute
+          .newBuilder()
           .setName(attr.name)
-          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
-      )
+          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)))
     }
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
10 changes: 10 additions & 0 deletions dev/lint-scala
@@ -21,3 +21,13 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
 SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
 
 "$SCRIPT_DIR/scalastyle" "$1"
+
+# For Spark Connect, we actively enforce scalafmt and check that the produced diff is empty.
+./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -pl connector/connect
+if [[ $? -ne 0 ]]; then
+  echo "The scalafmt check failed on connector/connect."
+  echo "Before submitting your change, please make sure to format your code using the following command:"
+  echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -pl connector/connect"
[Contributor] :) No one loves SBT

[Contributor] I am not familiar with SBT though maybe this works with SBT.

[Contributor, author] This piece here is taken from the existing dev/scalafmt script and slightly adjusted to only reformat the Spark Connect module. I don't think we've ported the usage of scalafmt to SBT.

+  exit 1
+fi
[Member] Hi, @grundprinzip and @HyukjinKwon. This PR makes lint-scala ignore the exit code of scalastyle.

[Contributor, author]
> if [[ $? -ne 0 ]]; then

This checks explicitly for a non-zero exit code. Why would this not cover your case?


2 changes: 1 addition & 1 deletion dev/scalafmt
@@ -18,5 +18,5 @@
 #
 
 VERSION="${@:-2.12}"
-./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false
+./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false

3 changes: 2 additions & 1 deletion pom.xml
@@ -172,6 +172,7 @@
     <scala-maven-plugin.version>4.7.1</scala-maven-plugin.version>
     <!-- for now, not running scalafmt as part of default verify pipeline -->
     <scalafmt.skip>true</scalafmt.skip>
+    <scalafmt.validateOnly>true</scalafmt.validateOnly>
     <codehaus.jackson.version>1.9.13</codehaus.jackson.version>
     <fasterxml.jackson.version>2.13.4</fasterxml.jackson.version>
     <fasterxml.jackson.databind.version>2.13.4.1</fasterxml.jackson.databind.version>
@@ -3412,7 +3413,7 @@
         <artifactId>mvn-scalafmt_${scala.binary.version}</artifactId>
         <version>1.1.1640084764.9f463a9</version>
         <configuration>
-          <validateOnly>${scalafmt.skip}</validateOnly> <!-- (Optional) skip formatting -->
+          <validateOnly>${scalafmt.validateOnly}</validateOnly> <!-- (Optional) skip formatting -->
           <skipSources>${scalafmt.skip}</skipSources>
           <skipTestSources>${scalafmt.skip}</skipTestSources>
           <configLocation>dev/.scalafmt.conf</configLocation> <!-- (Optional) config location -->