Skip to content

Commit

Permalink
ScalafmtRunner: use separate execution contexts
Browse files Browse the repository at this point in the history
- in both core and dynamic runners, reading and writing is done the same
  way, with the only difference being formatting; hence, let's move I/O
  to shared runner
- to avoid reading all files first and only then formatting and writing
  them, let's create separate input and output execution contexts and
  use the former for reading and formatting, and the latter for writing
- finally, let's define parasitic execution context (one which executes
  immediately, without putting the task in a queue), for short tasks
  such as triggering `.onComplete` or updating task progress
  • Loading branch information
kitbellew committed Feb 18, 2025
1 parent b822e23 commit c2e993d
Show file tree
Hide file tree
Showing 17 changed files with 140 additions and 131 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.scalafmt.cli

import org.scalafmt.sysops.AbsoluteFile
import org.scalafmt.sysops.PlatformRunOps.executionContext
import org.scalafmt.sysops.PlatformRunOps

import scala.io.Source

Expand All @@ -17,6 +17,7 @@ private[scalafmt] trait CliUtils {
.getWorkingDirectory}",
),
)
import PlatformRunOps.parasiticExecutionContext
Cli.mainWithOptions(
CliOptions.default.copy(common =
CliOptions.default.common.copy(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
package org.scalafmt.cli

import org.scalafmt.Error
import org.scalafmt.dynamic.ScalafmtDynamicError
import org.scalafmt.interfaces.Scalafmt
import org.scalafmt.interfaces.ScalafmtSession
import org.scalafmt.sysops.PlatformFileOps
import org.scalafmt.sysops.PlatformRunOps

import java.nio.file.Path

import scala.concurrent.Future

object ScalafmtDynamicRunner extends ScalafmtRunner {
import org.scalafmt.sysops.PlatformRunOps.executionContext

override private[cli] def run(
options: CliOptions,
Expand All @@ -31,7 +28,7 @@ object ScalafmtDynamicRunner extends ScalafmtRunner {

private def runWithSession(
options: CliOptions,
termDisplayMessage: String,
displayMsg: String,
reporter: ScalafmtCliReporter,
)(session: ScalafmtSession): Future[ExitCode] = {
val sessionMatcher = session.matchesProjectFilters _
Expand All @@ -42,23 +39,13 @@ object ScalafmtDynamicRunner extends ScalafmtRunner {
}
val inputMethods = getInputMethods(options, filterMatcher)
if (inputMethods.isEmpty) ExitCode.Ok.future
else runInputs(options, inputMethods, termDisplayMessage)(inputMethod =>
handleFile(inputMethod, session, options).recover {
case x: Error.MisformattedFile => reporter.fail(x)(x.file)
}.map(ExitCode.merge(_, reporter.getExitCode)),
)
else runInputs(options, inputMethods, displayMsg) { case (code, path) =>
val formatted = session.format(path, code)
val exitCode = reporter.getExitCode(path)
if (exitCode eq null) Right(formatted) else Left(exitCode)
}
}

private[this] def handleFile(
inputMethod: InputMethod,
session: ScalafmtSession,
options: CliOptions,
): Future[ExitCode] = inputMethod.readInput(options)
.map(code => code -> session.format(inputMethod.path, code))
.flatMap { case (code, formattedCode) =>
inputMethod.write(formattedCode, code, options)
}(PlatformRunOps.ioExecutionContext)

private def getFileMatcher(paths: Seq[Path]): Path => Boolean = {
val dirBuilder = Seq.newBuilder[Path]
val fileBuilder = Set.newBuilder[Path]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import scala.concurrent.Future

object Cli extends CliUtils {

import PlatformRunOps.executionContext
import PlatformRunOps.parasiticExecutionContext

def main(args: Array[String]): Unit =
mainWithOptions(CliOptions.default, args: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ sealed abstract class InputMethod {
case WriteMode.Stdout => print(formatted, options); exitCode.future
case _ if !codeChanged => ExitCode.Ok.future
case WriteMode.List => list(options); options.exitCodeOnChange.future
case WriteMode.Override => overwrite(formatted, options)
.map(_ => options.exitCodeOnChange)(PlatformRunOps.ioExecutionContext)
case WriteMode.Override => overwrite(formatted, options).map(_ =>
options.exitCodeOnChange,
)(PlatformRunOps.parasiticExecutionContext)
case WriteMode.Test =>
val pathStr = path.toString
val diff = InputMethod.unifiedDiff(pathStr, original, formatted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,31 @@ import org.scalafmt.interfaces._
import java.io.OutputStreamWriter
import java.io.PrintWriter
import java.nio.file.Path
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference

import scala.annotation.tailrec
import scala.util.control.NoStackTrace

class ScalafmtCliReporter(options: CliOptions) extends ScalafmtReporter {
private val exitCode = new AtomicReference(ExitCode.Ok)
private val exitCodePerFile = new ConcurrentHashMap[String, ExitCode]()

def getExitCode: ExitCode = exitCode.get()
private def updateExitCode(code: ExitCode): Unit =
if (!code.isOk) exitCode.getAndUpdate(ExitCode.merge(code, _))
def getExitCode(file: Path): ExitCode = exitCodePerFile.get(file.toString)

private def updateExitCode(code: ExitCode, file: Path): Unit = if (!code.isOk) {
exitCodePerFile.put(file.toString, code)
exitCode.getAndUpdate(ExitCode.merge(code, _))
}

override def error(file: Path, message: String): Unit =
if (!options.ignoreWarnings) {
options.common.err.println(s"$message: $file")
updateExitCode(ExitCode.UnexpectedError)
updateExitCode(ExitCode.UnexpectedError, file)
}
override final def error(file: Path, e: Throwable): Unit =
updateExitCode(fail(e)(file))
updateExitCode(fail(e)(file), file)
@tailrec
private[cli] final def fail(e: Throwable)(file: Path): ExitCode = e match {
case e: MisformattedFile =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import org.scalafmt.Versions
import org.scalafmt.config.ProjectFiles
import org.scalafmt.config.ScalafmtConfig
import org.scalafmt.config.ScalafmtConfigException
import org.scalafmt.sysops.PlatformRunOps

import scala.meta.parsers.ParseException
import scala.meta.tokenizers.TokenizeException
Expand All @@ -16,7 +15,6 @@ import scala.annotation.tailrec
import scala.concurrent.Future

object ScalafmtCoreRunner extends ScalafmtRunner {
import org.scalafmt.sysops.PlatformRunOps.executionContext

override private[cli] def run(
options: CliOptions,
Expand All @@ -36,10 +34,7 @@ object ScalafmtCoreRunner extends ScalafmtRunner {
}
}

private def runWithFilterMatcher(
options: CliOptions,
termDisplayMessage: String,
)(
private def runWithFilterMatcher(options: CliOptions, displayMsg: String)(
filterMatcher: ProjectFiles.FileMatcher,
)(implicit scalafmtConf: ScalafmtConfig): Future[ExitCode] = {
val inputMethods = getInputMethods(options, filterMatcher.matchesPath)
Expand All @@ -48,23 +43,18 @@ object ScalafmtCoreRunner extends ScalafmtRunner {
val adjustedScalafmtConf = {
if (scalafmtConf.needGitAutoCRLF) options.gitOps.getAutoCRLF else None
}.fold(scalafmtConf)(scalafmtConf.withGitAutoCRLF)

runInputs(options, inputMethods, termDisplayMessage)(inputMethod =>
handleFile(inputMethod, options, adjustedScalafmtConf)
.recover { case e: Error.MisformattedFile =>
options.common.err.println(e.customMessage)
ExitCode.TestError
},
)
runInputs(options, inputMethods, displayMsg) { case (code, path) =>
handleFile(code, path.toString, options, adjustedScalafmtConf)
}
}
}

private[this] def handleFile(
inputMethod: InputMethod,
code: String,
path: String,
options: CliOptions,
scalafmtConfig: ScalafmtConfig,
): Future[ExitCode] = {
val path = inputMethod.path.toString
): Either[ExitCode, String] = {
@tailrec
def handleError(e: Throwable): ExitCode = e match {
case Error.WithCode(e, _) => handleError(e)
Expand All @@ -75,20 +65,14 @@ object ScalafmtCoreRunner extends ScalafmtRunner {
new FailedToFormat(path, e).printStackTrace(options.common.err)
ExitCode.UnexpectedError
}
inputMethod.readInput(options).map { code =>
val res = Scalafmt.formatCode(code, scalafmtConfig, options.range, path)
res.formatted match {
case x: Formatted.Success => Right(code -> x.formattedCode)
case x: Formatted.Failure => Left(
if (res.config.runner.ignoreWarnings) ExitCode.Ok // do nothing
else handleError(x.e),
)
}
}.flatMap {
case Right((code, formattedCode)) => inputMethod
.write(formattedCode, code, options)
case Left(exitCode) => exitCode.future
}(PlatformRunOps.ioExecutionContext)
val res = Scalafmt.formatCode(code, scalafmtConfig, options.range, path)
res.formatted match {
case x: Formatted.Success => Right(x.formattedCode)
case x: Formatted.Failure => Left(
if (res.config.runner.ignoreWarnings) ExitCode.Ok // do nothing
else handleError(x.e),
)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,32 @@ import org.scalafmt.sysops._
import java.nio.file.Path

import scala.concurrent._
import scala.util.Failure
import scala.util.Success

trait ScalafmtRunner {
private[cli] def run(
options: CliOptions,
termDisplayMessage: String,
): Future[ExitCode]

protected def newTermDisplay(
private def newTermDisplay(
options: CliOptions,
inputMethods: Seq[InputMethod],
msg: String,
): TermDisplay = {
val termDisplay = new TermDisplay(
options.common.info.printWriter,
fallbackMode = options.nonInteractive || TermDisplay.defaultFallbackMode,
)
): Option[TermDisplay] =
if (
options.writeMode != WriteMode.Stdout && inputMethods.lengthCompare(5) > 0
) {
val termDisplay = new TermDisplay(
options.common.info.printWriter,
fallbackMode = options.nonInteractive || TermDisplay.defaultFallbackMode,
)
termDisplay.init()
termDisplay.startTask(msg, options.cwd.jfile)
termDisplay.taskLength(msg, inputMethods.length, 0)
}
termDisplay
}
Some(termDisplay)
} else None

protected def getInputMethods(
options: CliOptions,
Expand Down Expand Up @@ -74,35 +75,50 @@ trait ScalafmtRunner {
options: CliOptions,
inputMethods: Seq[InputMethod],
termDisplayMessage: String,
)(f: InputMethod => Future[ExitCode]): Future[ExitCode] = {
)(f: (String, Path) => Either[ExitCode, String]): Future[ExitCode] = {
val termDisplay = newTermDisplay(options, inputMethods, termDisplayMessage)

implicit val executionContext: ExecutionContext =
PlatformRunOps.executionContext
import PlatformRunOps.parasiticExecutionContext
val completed = Promise[ExitCode]()

val tasks = List.newBuilder[Future[ExitCode]]
inputMethods.foreach(inputMethod =>
inputMethods.foreach { inputMethod =>
if (!completed.isCompleted) {
val future = f(inputMethod)
future.onComplete(r =>
if (options.check && !r.toOption.exists(_.isOk)) completed
.tryComplete(r)
else termDisplay.taskProgress(termDisplayMessage),
val input = inputMethod.readInput(options)
val future = input.map(code =>
f(code, inputMethod.path).map(formatted => (code, formatted)),
).flatMap {
case Left(exitCode) => exitCode.future
case Right((code, formattedCode)) => inputMethod
.write(formattedCode, code, options).transform {
case Failure(e: Error.MisformattedFile) =>
options.common.err.println(e.customMessage)
Success(ExitCode.TestError)
case r =>
if (r.toOption.exists(_.isOk)) termDisplay
.foreach(_.taskProgress(termDisplayMessage))
r
}
}
if (options.check) future.onComplete(r =>
if (!r.toOption.exists(_.isOk)) completed.tryComplete(r),
)
tasks += future
},
)
}
}

Future.foldLeft(tasks.result())(ExitCode.Ok)(ExitCode.merge)
.onComplete(completed.tryComplete)

completed.future.onComplete { r =>
termDisplay.completedTask(termDisplayMessage, r.toOption.exists(_.isOk))
termDisplay.stop()
}
val res = completed.future
termDisplay.fold(res)(td =>
res.transform { r =>
td.completedTask(termDisplayMessage, r.toOption.exists(_.isOk))
td.stop()
r
},
)

completed.future
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object TermDisplay extends TermUtils {
private def shouldUpdate(): Boolean = shouldUpdateFlag
.compareAndSet(true, false)

def end(): Unit = if (isStarted.get()) {
def end(): Unit = if (isStarted.compareAndSet(true, false)) {
polling.cancel()
if (fallbackMode) processStopFallback() else processStop()
}
Expand Down Expand Up @@ -255,7 +255,7 @@ object TermDisplay extends TermUtils {
-info.fraction.sum
}

private def processStop(): Unit = {} // poison pill
private def processStop(): Unit = out.append("\n\n").flush() // poison pill

private def processUpdate(): Unit = if (shouldUpdate()) {
val (done0, downloads0) = downloads.synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import munit.FunSuite

abstract class AbstractCliTest extends FunSuite {

import org.scalafmt.sysops.PlatformRunOps.executionContext

def mkArgs(str: String): Array[String] = str.split(' ')

def runWith(root: AbsoluteFile, argStr: String)(implicit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ private[scalafmt] object PlatformRunOps {
implicit def executionContext: ExecutionContext =
scala.scalajs.concurrent.JSExecutionContext.Implicits.queue

def ioExecutionContext: ExecutionContext = executionContext

def getSingleThreadExecutionContext: ExecutionContext = executionContext // same one
implicit def parasiticExecutionContext: ExecutionContext =
GranularDialectAsyncOps.parasiticExecutionContext

def runArgv(cmd: Seq[String], cwd: Option[Path]): Try[String] = {
val options = cwd.fold(js.Dictionary[js.Any]())(cwd =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import java.nio.file.Path
import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContextExecutorService
import scala.sys.process.ProcessLogger
import scala.util.Failure
import scala.util.Success
Expand All @@ -13,11 +14,17 @@ private[scalafmt] object PlatformRunOps {

implicit def executionContext: ExecutionContext = ExecutionContext.global

def ioExecutionContext: ExecutionContext =
GranularPlatformAsyncOps.ioExecutionContext

def getSingleThreadExecutionContext: ExecutionContext = ExecutionContext
.fromExecutor(Executors.newSingleThreadExecutor())
val inputExecutionContext: ExecutionContextExecutorService = ExecutionContext
.fromExecutorService(
Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors()),
)
val outputExecutionContext: ExecutionContextExecutorService = ExecutionContext
.fromExecutorService(
Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors()),
)

implicit def parasiticExecutionContext: ExecutionContext =
GranularDialectAsyncOps.parasiticExecutionContext

def runArgv(cmd: Seq[String], cwd: Option[Path]): Try[String] = {
val err = new StringBuilder()
Expand Down
Loading

0 comments on commit c2e993d

Please sign in to comment.