diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
index 7ac33fa9324ac..4ecbfd123f0d2 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
@@ -39,7 +39,8 @@ class SparkConnectPlugin extends SparkPlugin {
/**
* Return the plugin's driver-side component.
*
- * @return The driver-side component.
+ * @return
+ * The driver-side component.
*/
override def driverPlugin(): DriverPlugin = new DriverPlugin {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
index ae606a6a72edd..47d421a0359bf 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -55,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
/**
* This is a helper function that registers a new Python function in the SparkSession.
*
- * Right now this function is very rudimentary and bare-bones just to showcase how it
- * is possible to remotely serialize a Python function and execute it on the Spark cluster.
- * If the Python version on the client and server diverge, the execution of the function that
- * is serialized will most likely fail.
+ * Right now this function is very rudimentary and bare-bones just to showcase how it is
+ * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+ * Python versions on the client and server diverge, the execution of the function that is
+ * serialized will most likely fail.
*
* @param cf
*/
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index f6553f7e90b64..401624e9882af 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -113,8 +113,11 @@ package object dsl {
object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
def select(exprs: proto.Expression*): proto.Relation = {
- proto.Relation.newBuilder().setProject(
- proto.Project.newBuilder()
+ proto.Relation
+ .newBuilder()
+ .setProject(
+ proto.Project
+ .newBuilder()
.setInput(logicalPlan)
.addAllExpressions(exprs.toIterable.asJava)
.build())
@@ -122,10 +125,10 @@ package object dsl {
}
def where(condition: proto.Expression): proto.Relation = {
- proto.Relation.newBuilder()
- .setFilter(
- proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
- ).build()
+ proto.Relation
+ .newBuilder()
+ .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
+ .build()
}
def join(
@@ -145,13 +148,14 @@ package object dsl {
}
def as(alias: String): proto.Relation = {
- proto.Relation.newBuilder(logicalPlan)
+ proto.Relation
+ .newBuilder(logicalPlan)
.setCommon(proto.RelationCommon.newBuilder().setAlias(alias))
.build()
}
- def groupBy(
- groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+ def groupBy(groupingExprs: proto.Expression*)(
+ aggregateExprs: proto.Expression*): proto.Relation = {
val agg = proto.Aggregate.newBuilder()
agg.setInput(logicalPlan)
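
A quick aside on how the DSL above is meant to be used (not part of this patch): the implicit DslLogicalPlan class lets test code chain proto builders into a plan. The sketch below assumes the proto alias and DSL imports used by the files in this diff; the table name and the empty condition expression are illustrative only.

// Hedged usage sketch of the DSL above, under the assumptions stated in the lead-in.
import org.apache.spark.sql.connect.dsl.plans._

// Leaf relation reading an illustrative table called "people".
val people: proto.Relation = proto.Relation
  .newBuilder()
  .setRead(
    proto.Read
      .newBuilder()
      .setNamedTable(proto.Read.NamedTable.newBuilder().addParts("people").build())
      .build())
  .build()

// Each DSL call wraps the previous relation in a new proto node:
// `where` adds a Filter, `as` attaches an alias through RelationCommon.
val plan: proto.Relation = people
  .where(proto.Expression.newBuilder().build()) // real code would set a condition here
  .as("filtered")
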
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5ad95a6b516ab..46072ec089e03 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -60,7 +60,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
- case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation)
+ case proto.Relation.RelTypeCase.LOCAL_RELATION =>
+ transformLocalRelation(rel.getLocalRelation)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -109,10 +110,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
// TODO: support the target field for *.
val projection =
if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) {
- Seq(UnresolvedStar(Option.empty))
- } else {
- rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
- }
+ Seq(UnresolvedStar(Option.empty))
+ } else {
+ rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
+ }
val project = logical.Project(projectList = projection.toSeq, child = baseRel)
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
@@ -141,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
* Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
*
* TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp,
- * Duration, Period.
+ * Duration, Period.
* @param lit
* @return
* Expression
@@ -167,9 +168,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
// Days since UNIX epoch.
case proto.Expression.Literal.LiteralTypeCase.DATE =>
expressions.Literal(lit.getDate, DateType)
- case _ => throw InvalidPlanInput(
- s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
- s"(${lit.getLiteralTypeCase.name})")
+ case _ =>
+ throw InvalidPlanInput(
+ s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+ s"(${lit.getLiteralTypeCase.name})")
}
}
@@ -188,7 +190,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
*
* TODO(SPARK-40546) We need to homogenize the function names for binary operators.
*
- * @param fun Proto representation of the function call.
+ * @param fun
+ * Proto representation of the function call.
* @return
*/
private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
@@ -278,11 +281,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
val groupingExprs =
rel.getGroupingExpressionsList.asScala
- .map(transformExpression)
- .map {
- case x @ UnresolvedAttribute(_) => x
- case x => UnresolvedAlias(x)
- }
+ .map(transformExpression)
+ .map {
+ case x @ UnresolvedAttribute(_) => x
+ case x => UnresolvedAlias(x)
+ }
logical.Aggregate(
child = transformRelation(rel.getInput),
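
To make the literal handling above concrete, here is a small hedged sketch (not from this patch) of the DATE case: the client encodes the value as days since the UNIX epoch, and transformLiteral maps it to a Catalyst Literal of DateType. The setDate builder method is assumed from the LiteralTypeCase.DATE branch shown above.

// Hedged sketch: builds a proto date literal; 19000 days after 1970-01-01 is 2022-01-08.
val dateLiteral = proto.Expression.Literal
  .newBuilder()
  .setDate(19000)
  .build()
// Per the DATE case in transformLiteral, this maps to expressions.Literal(19000, DateType).
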
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index b62917d94727e..7c494e39a69a0 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -49,7 +49,8 @@ import org.apache.spark.sql.execution.ExtendedMode
@Unstable
@Since("3.4.0")
class SparkConnectService(debug: Boolean)
- extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging {
+ extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
+ with Logging {
/**
* This is the main entry method for Spark Connect and all calls to execute a plan.
@@ -183,7 +184,6 @@ object SparkConnectService {
/**
* Starts the GRPC Service.
- *
*/
def startGRPCService(): Unit = {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
@@ -212,4 +212,3 @@ object SparkConnectService {
}
}
}
-
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 52b807f63bb03..84a6efb2baabd 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.ArrowUtils
-
@Unstable
@Since("3.4.0")
class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index ba6995bfc5a82..67518f3bdb172 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -108,16 +108,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}
test("Simple Project") {
- val readWithTable = proto.Read.newBuilder()
+ val readWithTable = proto.Read
+ .newBuilder()
.setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
.build()
val project =
- proto.Project.newBuilder()
+ proto.Project
+ .newBuilder()
.setInput(proto.Relation.newBuilder().setRead(readWithTable).build())
.addExpressions(
- proto.Expression.newBuilder()
- .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build()
- ).build()
+ proto.Expression
+ .newBuilder()
+ .setUnresolvedStar(UnresolvedStar.newBuilder().build())
+ .build())
+ .build()
val res = transform(proto.Relation.newBuilder.setProject(project).build())
assert(res !== null)
assert(res.nodeName == "Project")
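
For comparison (not part of this patch), the same plan can be composed through the DSL helpers reformatted earlier in this diff; this sketch assumes the dsl plans implicits are imported into the suite and reuses readWithTable and transform from the test above.

// Hedged sketch: DslLogicalPlan.select builds the same Project node the test assembles by hand.
val star = proto.Expression
  .newBuilder()
  .setUnresolvedStar(UnresolvedStar.newBuilder().build())
  .build()
val viaDsl: proto.Relation =
  proto.Relation.newBuilder().setRead(readWithTable).build().select(star)
assert(transform(viaDsl).nodeName == "Project")
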
diff --git a/dev/lint-scala b/dev/lint-scala
index 9c701ab463fe5..ad2be152cfad6 100755
--- a/dev/lint-scala
+++ b/dev/lint-scala
@@ -21,3 +21,12 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
"$SCRIPT_DIR/scalastyle" "$1"
+
+# For Spark Connect, we actively enforce scalafmt and check that the produced diff is empty.
+./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -Dscalafmt.changedOnly=false -pl connector/connect
+if [[ $? -ne 0 ]]; then
+ echo "The scalafmt check failed on connector/connect."
+ echo "Before submitting your change, please make sure to format your code using the following command:"
+ echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect"
+ exit 1
+fi
diff --git a/dev/scalafmt b/dev/scalafmt
index 56ff75fe7d383..3971f7a69e724 100755
--- a/dev/scalafmt
+++ b/dev/scalafmt
@@ -18,5 +18,5 @@
#
VERSION="${@:-2.12}"
-./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false
+./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false
diff --git a/pom.xml b/pom.xml
index 21aa29ef3b992..65dfcdb22340c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -172,6 +172,8 @@
4.7.1
<scalafmt.skip>true</scalafmt.skip>
+ <scalafmt.validateOnly>true</scalafmt.validateOnly>
+ <scalafmt.changedOnly>true</scalafmt.changedOnly>
<codehaus.jackson.version>1.9.13</codehaus.jackson.version>
<fasterxml.jackson.version>2.13.4</fasterxml.jackson.version>
<fasterxml.jackson.databind.version>2.13.4.1</fasterxml.jackson.databind.version>
@@ -3412,11 +3414,11 @@
<artifactId>mvn-scalafmt_${scala.binary.version}</artifactId>
<version>1.1.1640084764.9f463a9</version>
- <validateOnly>${scalafmt.skip}</validateOnly>
+ <validateOnly>${scalafmt.validateOnly}</validateOnly>
<skipTestSources>${scalafmt.skip}</skipTestSources>
<skipSources>${scalafmt.skip}</skipSources>
<configLocation>dev/.scalafmt.conf</configLocation>
- <onlyChangedFiles>true</onlyChangedFiles>
+ <onlyChangedFiles>${scalafmt.changedOnly}</onlyChangedFiles>