Skip to content

Commit

Permalink
WIP process pipeing
Browse files Browse the repository at this point in the history
  • Loading branch information
szymon-rd committed Aug 9, 2023
1 parent 7f22bf8 commit 02e3ba3
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 25 deletions.
165 changes: 142 additions & 23 deletions os/src-jvm/ProcessOps.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package os

import java.util.concurrent.{ArrayBlockingQueue, Semaphore, TimeUnit}

import collection.JavaConverters._
import scala.annotation.tailrec

/**
Expand All @@ -21,8 +21,11 @@ import scala.annotation.tailrec
* the standard stdin/stdout/stderr streams, using whatever protocol you
* want
*/

case class proc(command: Shellable*) {
def commandChunks: Seq[String] = command.flatMap(_.value)
def commandChunks = command.flatMap(_.value)

def pipeTo(next: proc) = ProcGroup(Seq(this, next))

/**
* Invokes the given subprocess like a function, passing in input and returning a
Expand Down Expand Up @@ -50,7 +53,7 @@ case class proc(command: Shellable*) {
* @param propagateEnv disable this to avoid passing in this parent process's
* environment variables to the subprocess
*/
def call(
override def call(
cwd: Path = null,
env: Map[String, String] = null,
stdin: ProcessInput = Pipe,
Expand Down Expand Up @@ -79,7 +82,6 @@ case class proc(command: Shellable*) {
mergeErrIntoOut,
propagateEnv
)
import collection.JavaConverters._

sub.join(timeout)

Expand All @@ -102,7 +104,7 @@ case class proc(command: Shellable*) {
* execute in parallel with the main thread. Thus make sure any data
* processing you do in those callbacks is thread safe!
*/
def spawn(
override def spawn(
cwd: Path = null,
env: Map[String, String] = null,
stdin: ProcessInput = Pipe,
Expand All @@ -111,28 +113,12 @@ case class proc(command: Shellable*) {
mergeErrIntoOut: Boolean = false,
propagateEnv: Boolean = true
): SubProcess = {
val builder = new java.lang.ProcessBuilder()

val baseEnv =
if (propagateEnv) sys.env
else Map()
for ((k, v) <- baseEnv ++ Option(env).getOrElse(Map())) {
if (v != null) builder.environment().put(k, v)
else builder.environment().remove(k)
}

builder.directory(Option(cwd).getOrElse(os.pwd).toIO)
val builder = buildProcess(commandChunks, cwd, env, stdin, stdout, stderr, mergeErrIntoOut, propagateEnv)

val cmdChunks = commandChunks
val commandStr = cmdChunks.mkString(" ")
lazy val proc: SubProcess = new SubProcess(
builder
.command(cmdChunks: _*)
.redirectInput(stdin.redirectFrom)
.redirectOutput(stdout.redirectTo)
.redirectError(stderr.redirectTo)
.redirectErrorStream(mergeErrIntoOut)
.start(),
builder.start(),
stdin.processInput(proc.stdin).map(new Thread(_, commandStr + " stdin thread")),
stdout.processOutput(proc.stdout).map(new Thread(_, commandStr + " stdout thread")),
stderr.processOutput(proc.stderr).map(new Thread(_, commandStr + " stderr thread"))
Expand All @@ -144,3 +130,136 @@ case class proc(command: Shellable*) {
proc
}
}

case class ProcGroup(commands: Seq[proc]) {
def pipeTo(next: proc) = ProcGroup(commands :+ next)

override def call(
cwd: Path = null,
env: Map[String, String] = null,
stdin: ProcessInput = Pipe,
stdout: ProcessOutput = Pipe,
stderr: ProcessOutput = os.Inherit,
mergeErrIntoOut: Boolean = false,
timeout: Long = -1,
check: Boolean = true,
propagateEnv: Boolean = true,
failFast: Boolean = true,
pipefail: Boolean = true
): CommandResult = {
val chunks = new java.util.concurrent.ConcurrentLinkedQueue[Either[geny.Bytes, geny.Bytes]]

val sub = spawn(
cwd,
env,
stdin,
if (stdout ne os.Pipe) stdout
else os.ProcessOutput.ReadBytes((buf, n) =>
chunks.add(Left(new geny.Bytes(java.util.Arrays.copyOf(buf, n))))
),
if (stderr ne os.Pipe) stderr
else os.ProcessOutput.ReadBytes((buf, n) =>
chunks.add(Right(new geny.Bytes(java.util.Arrays.copyOf(buf, n))))
),
mergeErrIntoOut,
propagateEnv,
failFast,
pipefail
)

sub.join(timeout)

val chunksSeq = chunks.iterator.asScala.toIndexedSeq
val res = CommandResult(commandChunks, sub.exitCode(), chunksSeq)
if (res.exitCode == 0 || !check) res
else throw SubprocessException(res)
}

override def spawn(
cwd: Path = null,
env: Map[String, String] = null,
stdin: ProcessInput = Pipe,
stdout: ProcessOutput = Pipe,
stderr: ProcessOutput = os.Inherit,
mergeErrIntoOut: Boolean = false,
propagateEnv: Boolean = true,
failFast: Boolean = true,
pipefail: Boolean = true
): SubProcess = {
assert(commands.nonEmpty)
val builders = commands.zipWithIndex.map {
case (proc(command), 0) =>
(buildProcess(command, cwd, env, stdin, Pipe, stderr, mergeErrIntoOut, propagateEnv),
(proc: Process) => {
val proc = new SubProcess(
proc,
stdin.processInput(proc.stdin).map(new Thread(_, commandStr + " stdin thread")),
None,
stderr.processOutput(proc.stderr).map(new Thread(_, commandStr + " stderr thread"))
)
})
case (proc(command), commands.length - 1) =>
(buildProcess(command, cwd, env, Pipe, stdout, stderr, mergeErrIntoOut, propagateEnv),
(proc: Process) => {
val proc = new SubProcess(
proc,
None,
stdout.processOutput(proc.stdout).map(new Thread(_, commandStr + " stdout thread")),
stderr.processOutput(proc.stderr).map(new Thread(_, commandStr + " stderr thread"))
)
})
case (proc(command), index) =>
(buildProcess(command, cwd, env, Pipe, Pipe, stderr, mergeErrIntoOut, propagateEnv),
(proc: Process) => {
val proc = new SubProcess(
proc,
None,
None,
stderr.processOutput(proc.stderr).map(new Thread(_, commandStr + " stderr thread"))
)
})
}

val processes: Seq[Process] = ProcessBuilder.startPipeline(builders.asJava).asScala.toSeq
val subprocesses = builders.zip(processes).map { case ((_, f), proc) => f(proc) }
subprocesses.flatMap(p => Seq(p.inputPumperThread, p.outputPumperThread, p.errorPumperThread).flatten)
.foreach(_.start())

val pipeline = ProcessesPipeline(subprocesses, failFast, pipefail)
pipeline.brokenPipeHandler.start()
pipeline
}

}

private[os] object ProcessOps {
def buildProcess(
command: Seq[String],
cwd: Path = null,
env: Map[String, String] = null,
stdin: ProcessInput = Pipe,
stdout: ProcessOutput = Pipe,
stderr: ProcessOutput = os.Inherit,
mergeErrIntoOut: Boolean = false,
propagateEnv: Boolean = true
): ProcessBuilder = {
val builder = new java.lang.ProcessBuilder()

val baseEnv =
if (propagateEnv) sys.env
else Map()
for ((k, v) <- baseEnv ++ Option(env).getOrElse(Map())) {
if (v != null) builder.environment().put(k, v)
else builder.environment().remove(k)
}

builder.directory(Option(cwd).getOrElse(os.pwd).toIO)

builder
.command(cmdChunks: _*)
.redirectInput(stdin.redirectFrom)
.redirectOutput(stdout.redirectTo)
.redirectError(stderr.redirectTo)
.redirectErrorStream(mergeErrIntoOut)
}
}
100 changes: 98 additions & 2 deletions os/src-jvm/SubProcess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@ import java.io._
import java.util.concurrent.TimeUnit

import scala.language.implicitConversions
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.LinkedTransferQueue

trait ProcessLike {

def exitCode(): Int

def isAlive(): Boolean

def destroy(): Unit

def destroyForcibly(): Unit

def close(): Unit

def waitFor(timeout: Long = -1): Boolean

def join(timeout: Long = -1): Boolean
}


/**
* Represents a spawn subprocess that has started and may or may not have
Expand All @@ -14,7 +34,7 @@ class SubProcess(
val inputPumperThread: Option[Thread],
val outputPumperThread: Option[Thread],
val errorPumperThread: Option[Thread]
) extends java.lang.AutoCloseable {
) extends java.lang.AutoCloseable with ProcessLike {
val stdin: SubProcess.InputStream = new SubProcess.InputStream(wrapped.getOutputStream)
val stdout: SubProcess.OutputStream = new SubProcess.OutputStream(wrapped.getInputStream)
val stderr: SubProcess.OutputStream = new SubProcess.OutputStream(wrapped.getErrorStream)
Expand Down Expand Up @@ -169,6 +189,82 @@ object SubProcess {
}
}

class ProcessesPipeline(
processes: Seq[SubProcess],
failFast: Boolean,
pipefail: Boolean
) extends AutoCloseable with ProcessLike {

val brokenPipeHandler: Thread = {
val finishedProcessQueue = new LinkedBlockingQueue[Int]()
val processExitListeners = processes.zipWithIndex.map { case (process, index) =>
new Thread(() => {
process.join()
finishedProcessQueue.put(index)
})
}

val pipeBreakListener = new Thread(() => {
processExitListeners.foreach(_.start())

var pipelineRunning = true
var highestBrokenPipeIndex = -1
while(pipelineRunning) {
val brokenPipeIndex = finishedProcessQueue.take()
if(brokenPipeIndex > highestBrokenPipeIndex) {
highestBrokenPipeIndex = brokenPipeIndex
if(brokenPipeIndex == processes.length - 1)
pipelineRunning = false
pipelineRunning = pipelineRunning && processes(brokenPipeIndex).exitCode() == 0

processes.take(brokenPipeIndex).filter(_.isAlive()).foreach(_.destroyForcibly())
}
}
processes.filter(_.isAlive()).foreach(_.destroyForcibly())
processExitListeners.foreach(_.join())
})
pipeBreakListener
}

override def exitCode(): Int = {
if (pipefail)
processes.map(_.exitCode())
.filter(_ != 0)
.headOption
.getOrElse(0)
else
processes.last.exitCode()
}

override def isAlive(): Boolean = {
processes.last.isAlive()
}

override def destroy(): Unit = {
processes.foreach(_.destroy())
}

override def destroyForcibly(): Unit = {
processes.foreach(_.destroyForcibly())
}

override def close(): Unit = {
processes.foreach(_.close())
}

override def waitFor(timeout: Long = -1): Boolean = {
processes.last.waitFor(timeout)
}

override def join(timeout: Long = -1): Boolean = {
processes.last.join(timeout)
}

override def close(): Boolean = {
processes.foreach(_.close())
}
}

/**
* Represents the configuration of a SubProcess's input stream. Can either be
* [[os.Inherit]], [[os.Pipe]], [[os.Path]] or a [[os.Source]]
Expand Down Expand Up @@ -273,4 +369,4 @@ case class PathRedirect(p: Path) extends ProcessInput with ProcessOutput {
case class PathAppendRedirect(p: Path) extends ProcessOutput {
def redirectTo = ProcessBuilder.Redirect.appendTo(p.toIO)
def processOutput(out: => SubProcess.OutputStream) = None
}
}

0 comments on commit 02e3ba3

Please sign in to comment.