diff --git a/build.sc b/build.sc index c8641795b..9bf01d8da 100644 --- a/build.sc +++ b/build.sc @@ -414,6 +414,7 @@ trait Launcher extends AlmondSimpleModule with BootstrapLauncher with PropertyFi } Seq( "kernel-main-class" -> mainClass, + "ammonite-version" -> Versions.ammonite, "default-scala212-version" -> ScalaVersions.scala212, "default-scala213-version" -> ScalaVersions.scala213, "default-scala-version" -> ScalaVersions.scala3Latest diff --git a/modules/scala/integration/src/test/scala/almond/integration/KernelTestsTwoStepStartup213.scala b/modules/scala/integration/src/test/scala/almond/integration/KernelTestsTwoStepStartup213.scala index 8171d02b4..7599d66b7 100644 --- a/modules/scala/integration/src/test/scala/almond/integration/KernelTestsTwoStepStartup213.scala +++ b/modules/scala/integration/src/test/scala/almond/integration/KernelTestsTwoStepStartup213.scala @@ -182,4 +182,71 @@ class KernelTestsTwoStepStartup213 extends KernelTestsDefinitions { } } + test("custom startup directives") { + + val tmpDir = os.temp.dir(prefix = "almond-custom-directives-test-") + + val handlerCode = + """//> using lib "com.lihaoyi::os-lib:0.9.1" + |//> using lib "com.lihaoyi::pprint:0.8.1" + |//> using lib "com.lihaoyi::upickle:3.1.0" + |//> using scala "2.13.11" + |//> using jvm "8" + | + |package handler + | + |object Handler { + | case class Entry(key: String, values: List[String]) + | implicit val entryCodec: upickle.default.ReadWriter[Entry] = upickle.default.macroRW + | + | def main(args: Array[String]): Unit = { + | assert(args.length == 1, "Usage: handle-spark-directives input.json") + | val inputEntries = upickle.default.read[List[Entry]](os.read.bytes(os.Path(args(0), os.pwd))) + | pprint.err.log(inputEntries) + | + | val version = inputEntries.find(_.key == "spark.version").flatMap(_.values.headOption).getOrElse("X.Y") + | + | val output = ujson.Obj( + | "javaCmd" -> ujson.Arr(Seq("java", s"-Dthe-version=$version").map(ujson.Str(_)): _*) + | ) + | + | println(output.render()) + | } + |} + |""".stripMargin + + val directivesHandler = tmpDir / "handle-spark-directives" + + os.write(tmpDir / "Handler.scala", handlerCode) + + os.proc( + Tests.java17Cmd, + "-jar", + Tests.scalaCliLauncher.toString, + "--power", + "package", + ".", + "-o", + directivesHandler + ).call(cwd = tmpDir) + + kernelLauncher.withKernel { implicit runner => + implicit val sessionId: SessionId = SessionId() + runner.withSession("--custom-directive-group", s"spark:$directivesHandler") { + implicit session => + execute( + """//> using spark.version "1.2.3" + |//> using spark + |""".stripMargin, + "" + ) + + execute( + """val version = sys.props.getOrElse("the-version", "nope")""", + """version: String = "1.2.3"""" + ) + } + } + } + } diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala b/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala index e9f3b5157..13f2067a8 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala @@ -41,16 +41,7 @@ object Launcher extends CaseApp[LauncherOptions] { logCtx: LoggerContext ): (os.proc, String, Option[String]) = { - val requestedScalaVersion = params0.scala - .orElse(options.scala.map(_.trim).filter(_.nonEmpty)) - .getOrElse(Properties.defaultScalaVersion) - - val scalaVersion = requestedScalaVersion match { - case "2.12" => Properties.defaultScala212Version - case "2" | "2.13" => Properties.defaultScala213Version - case "3" => Properties.defaultScalaVersion - case _ => requestedScalaVersion - } + val (scalaVersion, _) = LauncherInterpreter.computeScalaVersion(params0, options) def content(entries: Seq[(coursierapi.Artifact, File)]): ClassLoaderContent = { val entries0 = entries.map { @@ -170,9 +161,9 @@ object Launcher extends CaseApp[LauncherOptions] { ) val javaHome = os.Path(jvmManager.get(jvmId), os.pwd) val ext = if (scala.util.Properties.isWin) ".exe" else "" - (javaHome / "bin" / s"java$ext").toString + Seq((javaHome / "bin" / s"java$ext").toString) case None => - "java" + params0.javaCmd.getOrElse(Seq("java")) } val javaOptions = options.javaOpt ++ params0.javaOptions @@ -200,7 +191,7 @@ object Launcher extends CaseApp[LauncherOptions] { options.kernelOptions ) - (proc, requestedScalaVersion, jvmIdOpt) + (proc, scalaVersion, jvmIdOpt) } private def launchActualKernel(proc: os.proc): Unit = { @@ -372,7 +363,7 @@ object Launcher extends CaseApp[LauncherOptions] { interpreter.lineCount, options, firstMessageIdOpt.toSeq, - interpreter.params, + interpreter.params.processCustomDirectives(), interpreter.kernelOptions, outputHandlerOpt.getOrElse(OutputHandler.NopOutputHandler), logCtx diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherInterpreter.scala b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherInterpreter.scala index d373e2e6a..0ed4f76f1 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherInterpreter.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherInterpreter.scala @@ -12,6 +12,7 @@ import almond.protocol.KernelInfo import java.io.File import scala.cli.directivehandler._ +import scala.cli.directivehandler.DirectiveValueParser.DirectiveValueParserValueOps import scala.cli.directivehandler.EitherSequence._ class LauncherInterpreter( @@ -19,30 +20,87 @@ class LauncherInterpreter( options: LauncherOptions ) extends Interpreter { - def kernelInfo(): KernelInfo = + def kernelInfo(): KernelInfo = { + val (sv, svOrigin) = LauncherInterpreter.computeScalaVersion(params, options) KernelInfo( implementation = "scala", - implementation_version = "???", + implementation_version = Properties.version, language_info = KernelInfo.LanguageInfo( name = "scala", - version = "???", + version = Properties.version, mimetype = "text/x-scala", file_extension = ".sc", nbconvert_exporter = "script", codemirror_mode = Some("text/x-scala") ), banner = - s"""Almond ${"???"} - |Ammonite ${"???"} - |${"???"} - |Java ${"???"}""".stripMargin, // + + s"""Almond ${Properties.version} + |Ammonite ${Properties.ammoniteVersion} + |Scala $sv (from $svOrigin)""".stripMargin, // + // params.extraBannerOpt.fold("")("\n\n" + _), help_links = None // Some(params.extraLinks.toList).filter(_.nonEmpty) ) + } var kernelOptions = KernelOptions() var params = LauncherParameters() + val customDirectiveGroups = options.customDirectiveGroupsOrExit() + val launcherParametersHandlers = + LauncherInterpreter.launcherParametersHandlers.addCustomHandler { key => + customDirectiveGroups.find(_.matches(key)).map { group => + new DirectiveHandler[HasLauncherParameters] { + def name = s"custom group ${group.prefix}" + def description = s"custom group ${group.prefix}" + def usage = s"//> ${group.prefix}..." + + def keys = Seq(key) + def handleValues(scopedDirective: ScopedDirective) + : Either[DirectiveException, ProcessedDirective[HasLauncherParameters]] = { + assert(scopedDirective.directive.key == key) + val maybeValues = scopedDirective.directive.values + .filter(!_.isEmpty) + .map { value => + value.asString.toRight { + new MalformedDirectiveError( + s"Expected a string, got '${value.getRelatedASTNode.toString}'", + Seq(value.position(scopedDirective.maybePath)) + ) + } + } + .sequence + .left.map(CompositeDirectiveException(_)) + maybeValues.map { values => + ProcessedDirective( + Some( + new HasLauncherParameters { + def launcherParameters = + LauncherParameters(customDirectives = + Seq((group, scopedDirective.directive.key, values)) + ) + } + ), + Nil + ) + } + } + } + } + } + val kernelOptionsHandlers = LauncherInterpreter.kernelOptionsHandlers.addCustomHandler { key => + customDirectiveGroups.find(_.matches(key)).map { group => + new DirectiveHandler[HasKernelOptions] { + def name = s"custom group ${group.prefix}" + def description = s"custom group ${group.prefix}" + def usage = s"//> ${group.prefix}..." + def keys = Seq(key) + def handleValues(scopedDirective: ScopedDirective) + : Either[DirectiveException, ProcessedDirective[HasKernelOptions]] = + Right(ProcessedDirective(Some(HasKernelOptions.Ignore), Nil)) + } + } + } + def execute( code: String, storeHistory: Boolean, @@ -52,14 +110,14 @@ class LauncherInterpreter( val path = Left(s"cell$lineCount0.sc") val scopePath = ScopePath(Left("."), os.sub) val maybeParamsUpdate = - LauncherInterpreter.launcherParametersHandlers.parse(code, path, scopePath) + launcherParametersHandlers.parse(code, path, scopePath) .map { res => res .flatMap(_.global.map(_.launcherParameters).toSeq) .foldLeft(LauncherParameters())(_ + _) } val maybeKernelOptionsUpdate = - LauncherInterpreter.kernelOptionsHandlers.parse(code, path, scopePath) + kernelOptionsHandlers.parse(code, path, scopePath) .flatMap { res => res .flatMap(_.global.map(_.kernelOptions).toSeq) @@ -169,4 +227,21 @@ object LauncherInterpreter { fansi.Attrs.Empty ) } + + def computeScalaVersion( + params0: LauncherParameters, + options: LauncherOptions + ): (String, String) = { + + val requestedScalaVersion = params0.scala.map((_, "directive")) + .orElse(options.scala.map(_.trim).filter(_.nonEmpty).map((_, "command-line"))) + .getOrElse((Properties.defaultScalaVersion, "default")) + + requestedScalaVersion._1 match { + case "2.12" => (Properties.defaultScala212Version, requestedScalaVersion._2) + case "2" | "2.13" => (Properties.defaultScala213Version, requestedScalaVersion._2) + case "3" => (Properties.defaultScalaVersion, requestedScalaVersion._2) + case _ => requestedScalaVersion + } + } } diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala index 30cdb1d96..02e557512 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala @@ -1,8 +1,10 @@ package almond.launcher import almond.kernel.install.{Options => InstallOptions} +import almond.launcher.directives.CustomGroup import caseapp._ +import scala.cli.directivehandler.EitherSequence._ import scala.collection.mutable // format: off @@ -31,7 +33,8 @@ final case class LauncherOptions( javaOpt: List[String] = Nil, quiet: Option[Boolean] = None, silentImports: Option[Boolean] = None, - useNotebookCoursierLogger: Option[Boolean] = None + useNotebookCoursierLogger: Option[Boolean] = None, + customDirectiveGroup: List[String] = Nil ) { // format: on @@ -65,6 +68,27 @@ final case class LauncherOptions( } def quiet0 = quiet.getOrElse(true) + + def customDirectiveGroupsOrExit(): Seq[CustomGroup] = { + val maybeGroups = customDirectiveGroup + .map { input => + input.split(":", 2) match { + case Array(prefix, command) => Right(CustomGroup(prefix, command)) + case Array(_) => + Left(s"Malformed custom directive group argument, expected 'prefix:command': '$input'") + } + } + .sequence + + maybeGroups match { + case Left(errors) => + for (err <- errors) + System.err.println(err) + sys.exit(1) + case Right(groups) => + groups + } + } } object LauncherOptions { diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/Properties.scala b/modules/scala/launcher/src/main/scala/almond/launcher/Properties.scala index 801baa8cd..f4c48199e 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/Properties.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/Properties.scala @@ -22,6 +22,8 @@ object Properties { lazy val version = prop("version") lazy val commitHash = prop("commit-hash") + lazy val ammoniteVersion = prop("ammonite-version") + lazy val kernelMainClass = prop("kernel-main-class") lazy val defaultScalaVersion = prop("default-scala-version") lazy val defaultScala212Version = prop("default-scala212-version") diff --git a/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/CustomGroup.scala b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/CustomGroup.scala new file mode 100644 index 000000000..fd751d347 --- /dev/null +++ b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/CustomGroup.scala @@ -0,0 +1,10 @@ +package almond.launcher.directives + +case class CustomGroup( + prefix: String, + command: String +) { + private lazy val prefix0 = prefix + "." + def matches(key: String): Boolean = + key == prefix || key.startsWith(prefix0) +} diff --git a/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/JavaCommand.scala b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/JavaCommand.scala new file mode 100644 index 000000000..816927dbb --- /dev/null +++ b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/JavaCommand.scala @@ -0,0 +1,22 @@ +package almond.launcher.directives + +import scala.cli.directivehandler._ + +@DirectiveGroupName("Java command") +@DirectiveExamples("//> using javaCmd \"java\" \"-Dfoo=thing\"") +@DirectiveUsage( + "//> using javaCmd _args_", + "`//> using javaCmd` _args_" +) +@DirectiveDescription("Specify a java command to run the Scala version-specific kernel") +final case class JavaCommand( + javaCmd: Option[List[String]] = None +) extends HasLauncherParameters { + def launcherParameters = LauncherParameters( + javaCmd = javaCmd.filter(_.exists(_.nonEmpty)) + ) +} + +object JavaCommand { + val handler: DirectiveHandler[JavaCommand] = DirectiveHandler.deriver[JavaCommand].derive +} diff --git a/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/LauncherParameters.scala b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/LauncherParameters.scala index 41b4437a2..4f44f1e4b 100644 --- a/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/LauncherParameters.scala +++ b/modules/scala/shared-directives/src/main/scala/almond/launcher/directives/LauncherParameters.scala @@ -1,18 +1,55 @@ package almond.launcher.directives +import com.github.plokhotnyuk.jsoniter_scala.core._ +import com.github.plokhotnyuk.jsoniter_scala.macros._ + import scala.cli.directivehandler.{DirectiveHandler, DirectiveHandlers} final case class LauncherParameters( jvm: Option[String] = None, javaOptions: Seq[String] = Nil, - scala: Option[String] = None + scala: Option[String] = None, + javaCmd: Option[Seq[String]] = None, + customDirectives: Seq[(CustomGroup, String, Seq[String])] = Nil ) { def +(other: LauncherParameters): LauncherParameters = LauncherParameters( jvm.orElse(other.jvm), javaOptions ++ other.javaOptions, - scala.orElse(other.scala) + scala.orElse(other.scala), + javaCmd.orElse(other.javaCmd), + customDirectives = customDirectives ++ other.customDirectives ) + + def processCustomDirectives(): LauncherParameters = { + + import LauncherParameters.{AsJson, Entry, entriesCodec} + + var tmpFile0: os.Path = null + lazy val tmpFile = { + tmpFile0 = os.temp(prefix = "almond-launcher-params-", suffix = ".json") + tmpFile0 + } + + try + customDirectives + .map { + case (group, key, values) => + val entries = List(Entry(key, values.toList)) + os.write.over(tmpFile, writeToArray(entries)(entriesCodec)) + val res = os.proc(group.command, tmpFile) + .call(stdin = os.Inherit, check = false) + if (res.exitCode != 0) + sys.error( + s"Command ${group.command} for custom directives ${group.prefix} exited with code ${res.exitCode}" + ) + readFromArray(res.out.bytes)(AsJson.codec).params + } + .foldLeft(this)(_ + _) + finally + if (tmpFile0 != null) + os.remove(tmpFile0) + } } object LauncherParameters { @@ -21,8 +58,39 @@ object LauncherParameters { Seq[DirectiveHandler[HasLauncherParameters]]( JavaOptions.handler, Jvm.handler, - ScalaVersion.handler + ScalaVersion.handler, + JavaCommand.handler ) ) + private case class Entry(key: String, values: List[String]) + private val entriesCodec: JsonValueCodec[List[Entry]] = + JsonCodecMaker.makeWithRequiredCollectionFields + + private final case class AsJson( + jvm: Option[String] = None, + javaOptions: Seq[String] = Nil, + scala: Option[String] = None, + javaCmd: Option[Seq[String]] = None + ) { + def params: LauncherParameters = + LauncherParameters( + jvm = jvm, + javaOptions = javaOptions, + scala = scala, + javaCmd = javaCmd + ) + } + + private object AsJson { + val codec: JsonValueCodec[AsJson] = JsonCodecMaker.make + def from(params: LauncherParameters): AsJson = + AsJson( + jvm = params.jvm, + javaOptions = params.javaOptions, + scala = params.scala, + javaCmd = params.javaCmd + ) + } + } diff --git a/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala b/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala index 532ab0ecf..2b07d4aff 100644 --- a/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala +++ b/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala @@ -381,7 +381,7 @@ object Tests { } } - private def java17Cmd(): String = { + lazy val java17Cmd: String = { val isAtLeastJava17 = scala.util.Try(sys.props("java.version").takeWhile(_.isDigit).toInt).toOption.exists(_ >= 17) val javaHome = @@ -391,10 +391,10 @@ object Tests { new File(javaHome, "bin/java" + ext).toString } - private def scalaCliLauncher(): File = + lazy val scalaCliLauncher: File = coursierapi.Cache.create() .get(coursierapi.Artifact.of( - "https://github.com/VirtusLab/scala-cli/releases/download/v1.0.0-RC1/scala-cli" + "https://github.com/VirtusLab/scala-cli/releases/download/v1.0.1/scala-cli" )) def toreeAddJarCustomProtocol(scalaVersion: String)(implicit @@ -427,9 +427,9 @@ object Tests { os.write(tmpDir / "FooURLConnection.scala", code) val extraCp = os.proc( - java17Cmd(), + java17Cmd, "-jar", - scalaCliLauncher().toString, + scalaCliLauncher.toString, "--power", "compile", "--print-class-path", diff --git a/project/deps.sc b/project/deps.sc index 2657d77df..707132f47 100644 --- a/project/deps.sc +++ b/project/deps.sc @@ -50,7 +50,7 @@ object Deps { def coursierApi = ivy"io.get-coursier:interface:1.0.18" def coursierLauncher = ivy"io.get-coursier:coursier-launcher_2.13:${Versions.coursier}" def dependencyInterface = ivy"io.get-coursier::dependency-interface:0.2.3" - def directiveHandler = ivy"io.github.alexarchambault.scala-cli::directive-handler:0.1.2" + def directiveHandler = ivy"io.github.alexarchambault.scala-cli::directive-handler:0.1.3" def expecty = ivy"com.eed3si9n.expecty::expecty:0.16.0" def fansi = ivy"com.lihaoyi::fansi:0.4.0" def fs2(sv: String) =