diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
index 7ac33fa9324ac..4ecbfd123f0d2 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala
@@ -39,7 +39,8 @@ class SparkConnectPlugin extends SparkPlugin {
   /**
    * Return the plugin's driver-side component.
    *
-   * @return The driver-side component.
+   * @return
+   *   The driver-side component.
    */
   override def driverPlugin(): DriverPlugin = new DriverPlugin {
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
index ae606a6a72edd..47d421a0359bf 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
@@ -55,10 +55,10 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
   /**
    * This is a helper function that registers a new Python function in the SparkSession.
    *
-   * Right now this function is very rudimentary and bare-bones just to showcase how it
-   * is possible to remotely serialize a Python function and execute it on the Spark cluster.
-   * If the Python version on the client and server diverge, the execution of the function that
-   * is serialized will most likely fail.
+   * Right now this function is very rudimentary and bare-bones just to showcase how it is
+   * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+   * Python version on the client and server diverge, the execution of the function that is
+   * serialized will most likely fail.
    *
    * @param cf
    */
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index f6553f7e90b64..401624e9882af 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -113,8 +113,11 @@ package object dsl {
   object plans { // scalastyle:ignore
     implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
       def select(exprs: proto.Expression*): proto.Relation = {
-        proto.Relation.newBuilder().setProject(
-          proto.Project.newBuilder()
+        proto.Relation
+          .newBuilder()
+          .setProject(
+            proto.Project
+              .newBuilder()
               .setInput(logicalPlan)
               .addAllExpressions(exprs.toIterable.asJava)
               .build())
@@ -122,10 +125,10 @@ package object dsl {
       }
 
       def where(condition: proto.Expression): proto.Relation = {
-        proto.Relation.newBuilder()
-          .setFilter(
-            proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
-          ).build()
+        proto.Relation
+          .newBuilder()
+          .setFilter(proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition))
+          .build()
       }
 
       def join(
@@ -145,13 +148,14 @@ package object dsl {
       }
 
       def as(alias: String): proto.Relation = {
-        proto.Relation.newBuilder(logicalPlan)
+        proto.Relation
+          .newBuilder(logicalPlan)
           .setCommon(proto.RelationCommon.newBuilder().setAlias(alias))
           .build()
       }
 
-      def groupBy(
-          groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+      def groupBy(groupingExprs: proto.Expression*)(
+          aggregateExprs: proto.Expression*): proto.Relation = {
         val agg = proto.Aggregate.newBuilder()
         agg.setInput(logicalPlan)
 
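The DslLogicalPlan implicit above is what lets test code assemble proto.Relation trees fluently, and the reformatted builder chains read the same way at a call site. A minimal sketch of such a call site, assuming the import that follows from the package object shown above; the values baseRel (a proto.Relation) and cond (a proto.Expression) are illustrative names for inputs built elsewhere, for example in a test fixture, and are not part of this patch:

    import org.apache.spark.sql.connect.dsl.plans._

    // Compose a filter and an alias on top of an existing proto.Relation.
    // `baseRel` and `cond` are assumed inputs (see the note above).
    val filtered = baseRel
      .where(cond)
      .as("filtered")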
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5ad95a6b516ab..46072ec089e03 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -60,7 +60,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
       case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
       case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
-      case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation)
+      case proto.Relation.RelTypeCase.LOCAL_RELATION =>
+        transformLocalRelation(rel.getLocalRelation)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -109,10 +110,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     // TODO: support the target field for *.
     val projection =
       if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) {
-      Seq(UnresolvedStar(Option.empty))
-    } else {
-      rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
-    }
+        Seq(UnresolvedStar(Option.empty))
+      } else {
+        rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
+      }
     val project = logical.Project(projectList = projection.toSeq, child = baseRel)
     if (common.nonEmpty && common.get.getAlias.nonEmpty) {
       logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
@@ -141,7 +142,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    * Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
    *
    * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp,
-   *   Duration, Period.
+   * Duration, Period.
    * @param lit
    * @return
    *   Expression
@@ -167,9 +168,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       // Days since UNIX epoch.
       case proto.Expression.Literal.LiteralTypeCase.DATE =>
         expressions.Literal(lit.getDate, DateType)
-      case _ => throw InvalidPlanInput(
-        s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
-          s"(${lit.getLiteralTypeCase.name})")
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
+            s"(${lit.getLiteralTypeCase.name})")
     }
   }
 
@@ -188,7 +190,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    *
    * TODO(SPARK-40546) We need to homogenize the function names for binary operators.
    *
-   * @param fun Proto representation of the function call.
+   * @param fun
+   *   Proto representation of the function call.
    * @return
    */
   private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
@@ -278,11 +281,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
     val groupingExprs =
       rel.getGroupingExpressionsList.asScala
-      .map(transformExpression)
-      .map {
-        case x @ UnresolvedAttribute(_) => x
-        case x => UnresolvedAlias(x)
-      }
+        .map(transformExpression)
+        .map {
+          case x @ UnresolvedAttribute(_) => x
+          case x => UnresolvedAlias(x)
+        }
 
     logical.Aggregate(
       child = transformRelation(rel.getInput),
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index b62917d94727e..7c494e39a69a0 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -49,7 +49,8 @@ import org.apache.spark.sql.execution.ExtendedMode
 @Unstable
 @Since("3.4.0")
 class SparkConnectService(debug: Boolean)
-  extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging {
+    extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
+    with Logging {
 
   /**
    * This is the main entry method for Spark Connect and all calls to execute a plan.
@@ -183,7 +184,6 @@ object SparkConnectService {
 
   /**
    * Starts the GRPC Serivce.
-   *
    */
   def startGRPCService(): Unit = {
     val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
@@ -212,4 +212,3 @@ object SparkConnectService {
     }
   }
 }
-
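One behavioral detail worth noting from the SparkConnectService hunks above: startGRPCService() reads spark.connect.grpc.debug.enabled with a default of true, and the service class takes a debug flag. A hedged sketch of switching that flag off through the ordinary SparkConf API; the flag name and its default come from the hunk above, everything else is standard Spark configuration:

    import org.apache.spark.SparkConf

    // Start Spark Connect without the debug behavior that
    // "spark.connect.grpc.debug.enabled" (default: true) enables.
    val conf = new SparkConf()
      .set("spark.connect.grpc.debug.enabled", "false")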
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 52b807f63bb03..84a6efb2baabd 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.util.ArrowUtils
 
-
 @Unstable
 @Since("3.4.0")
 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index ba6995bfc5a82..67518f3bdb172 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -108,16 +108,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   }
 
   test("Simple Project") {
-    val readWithTable = proto.Read.newBuilder()
+    val readWithTable = proto.Read
+      .newBuilder()
       .setNamedTable(proto.Read.NamedTable.newBuilder.addParts("name").build())
       .build()
     val project =
-      proto.Project.newBuilder()
+      proto.Project
+        .newBuilder()
         .setInput(proto.Relation.newBuilder().setRead(readWithTable).build())
         .addExpressions(
-          proto.Expression.newBuilder()
-            .setUnresolvedStar(UnresolvedStar.newBuilder().build()).build()
-        ).build()
+          proto.Expression
+            .newBuilder()
+            .setUnresolvedStar(UnresolvedStar.newBuilder().build())
+            .build())
+        .build()
     val res = transform(proto.Relation.newBuilder.setProject(project).build())
     assert(res !== null)
     assert(res.nodeName == "Project")
diff --git a/dev/lint-scala b/dev/lint-scala
index 9c701ab463fe5..ad2be152cfad6 100755
--- a/dev/lint-scala
+++ b/dev/lint-scala
@@ -21,3 +21,12 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
 SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
 
 "$SCRIPT_DIR/scalastyle" "$1"
+
+# For Spark Connect, we actively enforce scalafmt and check that the produced diff is empty.
+./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=true -Dscalafmt.changedOnly=false -pl connector/connect
+if [[ $? -ne 0 ]]; then
+  echo "The scalafmt check failed on connector/connect."
+  echo "Before submitting your change, please make sure to format your code using the following command:"
+  echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect"
+  exit 1
+fi
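With the lint-scala hunk above, the standard lint entry point now runs scalastyle as before and then a scalafmt validation restricted to the connect module: -Dscalafmt.validateOnly=true makes the plugin fail instead of rewriting files, and -Dscalafmt.changedOnly=false makes it inspect every file under connector/connect rather than only files that differ from git. Because the script calls ./build/mvn with a relative path, the usual way to run it is from the repository root:

    ./dev/lint-scala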
+ echo "Before submitting your change, please make sure to format your code using the following command:" + echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect" + exit 1 +fi diff --git a/dev/scalafmt b/dev/scalafmt index 56ff75fe7d383..3971f7a69e724 100755 --- a/dev/scalafmt +++ b/dev/scalafmt @@ -18,5 +18,5 @@ # VERSION="${@:-2.12}" -./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false +./build/mvn -Pscala-$VERSION scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false diff --git a/pom.xml b/pom.xml index 21aa29ef3b992..65dfcdb22340c 100644 --- a/pom.xml +++ b/pom.xml @@ -172,6 +172,8 @@ 4.7.1 true + true + true 1.9.13 2.13.4 2.13.4.1 @@ -3412,11 +3414,11 @@ mvn-scalafmt_${scala.binary.version} 1.1.1640084764.9f463a9 - ${scalafmt.skip} + ${scalafmt.validateOnly} ${scalafmt.skip} ${scalafmt.skip} dev/.scalafmt.conf - true + ${scalafmt.changedOnly}