Skip to content

Commit

Permalink
Merge pull request #1192 from alexarchambault/custom-directive-groups
Browse files Browse the repository at this point in the history
Add support for custom directive groups
  • Loading branch information
alexarchambault authored Jul 6, 2023
2 parents 6c39c02 + 9194e87 commit da03dc1
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 33 deletions.
1 change: 1 addition & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ trait Launcher extends AlmondSimpleModule with BootstrapLauncher with PropertyFi
}
Seq(
"kernel-main-class" -> mainClass,
"ammonite-version" -> Versions.ammonite,
"default-scala212-version" -> ScalaVersions.scala212,
"default-scala213-version" -> ScalaVersions.scala213,
"default-scala-version" -> ScalaVersions.scala3Latest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,71 @@ class KernelTestsTwoStepStartup213 extends KernelTestsDefinitions {
}
}

test("custom startup directives") {

val tmpDir = os.temp.dir(prefix = "almond-custom-directives-test-")

val handlerCode =
"""//> using lib "com.lihaoyi::os-lib:0.9.1"
|//> using lib "com.lihaoyi::pprint:0.8.1"
|//> using lib "com.lihaoyi::upickle:3.1.0"
|//> using scala "2.13.11"
|//> using jvm "8"
|
|package handler
|
|object Handler {
| case class Entry(key: String, values: List[String])
| implicit val entryCodec: upickle.default.ReadWriter[Entry] = upickle.default.macroRW
|
| def main(args: Array[String]): Unit = {
| assert(args.length == 1, "Usage: handle-spark-directives input.json")
| val inputEntries = upickle.default.read[List[Entry]](os.read.bytes(os.Path(args(0), os.pwd)))
| pprint.err.log(inputEntries)
|
| val version = inputEntries.find(_.key == "spark.version").flatMap(_.values.headOption).getOrElse("X.Y")
|
| val output = ujson.Obj(
| "javaCmd" -> ujson.Arr(Seq("java", s"-Dthe-version=$version").map(ujson.Str(_)): _*)
| )
|
| println(output.render())
| }
|}
|""".stripMargin

val directivesHandler = tmpDir / "handle-spark-directives"

os.write(tmpDir / "Handler.scala", handlerCode)

os.proc(
Tests.java17Cmd,
"-jar",
Tests.scalaCliLauncher.toString,
"--power",
"package",
".",
"-o",
directivesHandler
).call(cwd = tmpDir)

kernelLauncher.withKernel { implicit runner =>
implicit val sessionId: SessionId = SessionId()
runner.withSession("--custom-directive-group", s"spark:$directivesHandler") {
implicit session =>
execute(
"""//> using spark.version "1.2.3"
|//> using spark
|""".stripMargin,
""
)

execute(
"""val version = sys.props.getOrElse("the-version", "nope")""",
"""version: String = "1.2.3""""
)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,7 @@ object Launcher extends CaseApp[LauncherOptions] {
logCtx: LoggerContext
): (os.proc, String, Option[String]) = {

val requestedScalaVersion = params0.scala
.orElse(options.scala.map(_.trim).filter(_.nonEmpty))
.getOrElse(Properties.defaultScalaVersion)

val scalaVersion = requestedScalaVersion match {
case "2.12" => Properties.defaultScala212Version
case "2" | "2.13" => Properties.defaultScala213Version
case "3" => Properties.defaultScalaVersion
case _ => requestedScalaVersion
}
val (scalaVersion, _) = LauncherInterpreter.computeScalaVersion(params0, options)

def content(entries: Seq[(coursierapi.Artifact, File)]): ClassLoaderContent = {
val entries0 = entries.map {
Expand Down Expand Up @@ -170,9 +161,9 @@ object Launcher extends CaseApp[LauncherOptions] {
)
val javaHome = os.Path(jvmManager.get(jvmId), os.pwd)
val ext = if (scala.util.Properties.isWin) ".exe" else ""
(javaHome / "bin" / s"java$ext").toString
Seq((javaHome / "bin" / s"java$ext").toString)
case None =>
"java"
params0.javaCmd.getOrElse(Seq("java"))
}

val javaOptions = options.javaOpt ++ params0.javaOptions
Expand Down Expand Up @@ -200,7 +191,7 @@ object Launcher extends CaseApp[LauncherOptions] {
options.kernelOptions
)

(proc, requestedScalaVersion, jvmIdOpt)
(proc, scalaVersion, jvmIdOpt)
}

private def launchActualKernel(proc: os.proc): Unit = {
Expand Down Expand Up @@ -372,7 +363,7 @@ object Launcher extends CaseApp[LauncherOptions] {
interpreter.lineCount,
options,
firstMessageIdOpt.toSeq,
interpreter.params,
interpreter.params.processCustomDirectives(),
interpreter.kernelOptions,
outputHandlerOpt.getOrElse(OutputHandler.NopOutputHandler),
logCtx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,95 @@ import almond.protocol.KernelInfo
import java.io.File

import scala.cli.directivehandler._
import scala.cli.directivehandler.DirectiveValueParser.DirectiveValueParserValueOps
import scala.cli.directivehandler.EitherSequence._

class LauncherInterpreter(
connectionFile: String,
options: LauncherOptions
) extends Interpreter {

def kernelInfo(): KernelInfo =
def kernelInfo(): KernelInfo = {
val (sv, svOrigin) = LauncherInterpreter.computeScalaVersion(params, options)
KernelInfo(
implementation = "scala",
implementation_version = "???",
implementation_version = Properties.version,
language_info = KernelInfo.LanguageInfo(
name = "scala",
version = "???",
version = Properties.version,
mimetype = "text/x-scala",
file_extension = ".sc",
nbconvert_exporter = "script",
codemirror_mode = Some("text/x-scala")
),
banner =
s"""Almond ${"???"}
|Ammonite ${"???"}
|${"???"}
|Java ${"???"}""".stripMargin, // +
s"""Almond ${Properties.version}
|Ammonite ${Properties.ammoniteVersion}
|Scala $sv (from $svOrigin)""".stripMargin, // +
// params.extraBannerOpt.fold("")("\n\n" + _),
help_links = None // Some(params.extraLinks.toList).filter(_.nonEmpty)
)
}

var kernelOptions = KernelOptions()
var params = LauncherParameters()

val customDirectiveGroups = options.customDirectiveGroupsOrExit()
val launcherParametersHandlers =
LauncherInterpreter.launcherParametersHandlers.addCustomHandler { key =>
customDirectiveGroups.find(_.matches(key)).map { group =>
new DirectiveHandler[HasLauncherParameters] {
def name = s"custom group ${group.prefix}"
def description = s"custom group ${group.prefix}"
def usage = s"//> ${group.prefix}..."

def keys = Seq(key)
def handleValues(scopedDirective: ScopedDirective)
: Either[DirectiveException, ProcessedDirective[HasLauncherParameters]] = {
assert(scopedDirective.directive.key == key)
val maybeValues = scopedDirective.directive.values
.filter(!_.isEmpty)
.map { value =>
value.asString.toRight {
new MalformedDirectiveError(
s"Expected a string, got '${value.getRelatedASTNode.toString}'",
Seq(value.position(scopedDirective.maybePath))
)
}
}
.sequence
.left.map(CompositeDirectiveException(_))
maybeValues.map { values =>
ProcessedDirective(
Some(
new HasLauncherParameters {
def launcherParameters =
LauncherParameters(customDirectives =
Seq((group, scopedDirective.directive.key, values))
)
}
),
Nil
)
}
}
}
}
}
val kernelOptionsHandlers = LauncherInterpreter.kernelOptionsHandlers.addCustomHandler { key =>
customDirectiveGroups.find(_.matches(key)).map { group =>
new DirectiveHandler[HasKernelOptions] {
def name = s"custom group ${group.prefix}"
def description = s"custom group ${group.prefix}"
def usage = s"//> ${group.prefix}..."
def keys = Seq(key)
def handleValues(scopedDirective: ScopedDirective)
: Either[DirectiveException, ProcessedDirective[HasKernelOptions]] =
Right(ProcessedDirective(Some(HasKernelOptions.Ignore), Nil))
}
}
}

def execute(
code: String,
storeHistory: Boolean,
Expand All @@ -52,14 +110,14 @@ class LauncherInterpreter(
val path = Left(s"cell$lineCount0.sc")
val scopePath = ScopePath(Left("."), os.sub)
val maybeParamsUpdate =
LauncherInterpreter.launcherParametersHandlers.parse(code, path, scopePath)
launcherParametersHandlers.parse(code, path, scopePath)
.map { res =>
res
.flatMap(_.global.map(_.launcherParameters).toSeq)
.foldLeft(LauncherParameters())(_ + _)
}
val maybeKernelOptionsUpdate =
LauncherInterpreter.kernelOptionsHandlers.parse(code, path, scopePath)
kernelOptionsHandlers.parse(code, path, scopePath)
.flatMap { res =>
res
.flatMap(_.global.map(_.kernelOptions).toSeq)
Expand Down Expand Up @@ -169,4 +227,21 @@ object LauncherInterpreter {
fansi.Attrs.Empty
)
}

def computeScalaVersion(
params0: LauncherParameters,
options: LauncherOptions
): (String, String) = {

val requestedScalaVersion = params0.scala.map((_, "directive"))
.orElse(options.scala.map(_.trim).filter(_.nonEmpty).map((_, "command-line")))
.getOrElse((Properties.defaultScalaVersion, "default"))

requestedScalaVersion._1 match {
case "2.12" => (Properties.defaultScala212Version, requestedScalaVersion._2)
case "2" | "2.13" => (Properties.defaultScala213Version, requestedScalaVersion._2)
case "3" => (Properties.defaultScalaVersion, requestedScalaVersion._2)
case _ => requestedScalaVersion
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package almond.launcher

import almond.kernel.install.{Options => InstallOptions}
import almond.launcher.directives.CustomGroup
import caseapp._

import scala.cli.directivehandler.EitherSequence._
import scala.collection.mutable

// format: off
Expand Down Expand Up @@ -31,7 +33,8 @@ final case class LauncherOptions(
javaOpt: List[String] = Nil,
quiet: Option[Boolean] = None,
silentImports: Option[Boolean] = None,
useNotebookCoursierLogger: Option[Boolean] = None
useNotebookCoursierLogger: Option[Boolean] = None,
customDirectiveGroup: List[String] = Nil
) {
// format: on

Expand Down Expand Up @@ -65,6 +68,27 @@ final case class LauncherOptions(
}

def quiet0 = quiet.getOrElse(true)

def customDirectiveGroupsOrExit(): Seq[CustomGroup] = {
val maybeGroups = customDirectiveGroup
.map { input =>
input.split(":", 2) match {
case Array(prefix, command) => Right(CustomGroup(prefix, command))
case Array(_) =>
Left(s"Malformed custom directive group argument, expected 'prefix:command': '$input'")
}
}
.sequence

maybeGroups match {
case Left(errors) =>
for (err <- errors)
System.err.println(err)
sys.exit(1)
case Right(groups) =>
groups
}
}
}

object LauncherOptions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ object Properties {
lazy val version = prop("version")
lazy val commitHash = prop("commit-hash")

lazy val ammoniteVersion = prop("ammonite-version")

lazy val kernelMainClass = prop("kernel-main-class")
lazy val defaultScalaVersion = prop("default-scala-version")
lazy val defaultScala212Version = prop("default-scala212-version")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package almond.launcher.directives

case class CustomGroup(
prefix: String,
command: String
) {
private lazy val prefix0 = prefix + "."
def matches(key: String): Boolean =
key == prefix || key.startsWith(prefix0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package almond.launcher.directives

import scala.cli.directivehandler._

@DirectiveGroupName("Java command")
@DirectiveExamples("//> using javaCmd \"java\" \"-Dfoo=thing\"")
@DirectiveUsage(
"//> using javaCmd _args_",
"`//> using javaCmd` _args_"
)
@DirectiveDescription("Specify a java command to run the Scala version-specific kernel")
final case class JavaCommand(
javaCmd: Option[List[String]] = None
) extends HasLauncherParameters {
def launcherParameters = LauncherParameters(
javaCmd = javaCmd.filter(_.exists(_.nonEmpty))
)
}

object JavaCommand {
val handler: DirectiveHandler[JavaCommand] = DirectiveHandler.deriver[JavaCommand].derive
}
Loading

0 comments on commit da03dc1

Please sign in to comment.