diff --git a/core/src/main/scala/org/apache/spark/deploy/RedirectConsolePlugin.scala b/core/src/main/scala/org/apache/spark/deploy/RedirectConsolePlugin.scala
new file mode 100644
index 000000000000..cc1995a264fe
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/RedirectConsolePlugin.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{ByteArrayOutputStream, PrintStream}
+import java.util.{Collections, Map => JMap}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
+import org.apache.spark.internal.{Logging, SparkLoggerFactory}
+import org.apache.spark.internal.config._
+
+/**
+ * A built-in plugin that redirects stdout/stderr to the logging system (SLF4J).
+ */
+class RedirectConsolePlugin extends SparkPlugin {
+  override def driverPlugin(): DriverPlugin = new DriverRedirectConsolePlugin()
+
+  override def executorPlugin(): ExecutorPlugin = new ExecRedirectConsolePlugin()
+}
+
+object RedirectConsolePlugin {
+
+  def redirectStdoutToLog(): Unit = {
+    val stdoutLogger = SparkLoggerFactory.getLogger("stdout")
+    System.setOut(new LoggingPrintStream(stdoutLogger.info))
+  }
+
+  def redirectStderrToLog(): Unit = {
+    val stderrLogger = SparkLoggerFactory.getLogger("stderr")
+    System.setErr(new LoggingPrintStream(stderrLogger.error))
+  }
+}
+
+class DriverRedirectConsolePlugin extends DriverPlugin with Logging {
+
+  override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
+    val outputs = sc.conf.get(DRIVER_REDIRECT_CONSOLE_OUTPUTS)
+    if (outputs.contains("stdout")) {
+      logInfo("Redirecting driver's stdout to the logging system.")
+      RedirectConsolePlugin.redirectStdoutToLog()
+    }
+    if (outputs.contains("stderr")) {
+      logInfo("Redirecting driver's stderr to the logging system.")
+      RedirectConsolePlugin.redirectStderrToLog()
+    }
+    Collections.emptyMap
+  }
+}
+
+class ExecRedirectConsolePlugin extends ExecutorPlugin with Logging {
+
+  override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
+    val outputs = ctx.conf.get(EXEC_REDIRECT_CONSOLE_OUTPUTS)
+    if (outputs.contains("stdout")) {
+      logInfo("Redirecting executor's stdout to the logging system.")
+      RedirectConsolePlugin.redirectStdoutToLog()
+    }
+    if (outputs.contains("stderr")) {
+      logInfo("Redirecting executor's stderr to the logging system.")
+      RedirectConsolePlugin.redirectStderrToLog()
+    }
+  }
+}
+
+private[spark] class LoggingPrintStream(redirect: String => Unit)
+  extends PrintStream(new LineBuffer(4 * 1024 * 1024)) {
+
+  override def write(b: Int): Unit = {
+    super.write(b)
+    tryLogCurrentLine()
+  }
+
+  override def write(buf: Array[Byte], off: Int, len: Int): Unit = {
+    super.write(buf, off, len)
+    tryLogCurrentLine()
+  }
+
+  private def tryLogCurrentLine(): Unit = this.synchronized {
+    out.asInstanceOf[LineBuffer].tryGenerateContext.foreach { logContext =>
+      redirect(logContext)
+    }
+  }
+}
+
+/**
+ * Buffers bytes until a line ending is seen. When the current line is complete, or the buffered
+ * size reaches the threshold, the buffered bytes can be emitted as one line.
+ */
+private[spark] object LineBuffer {
+  private val LF_BYTES = System.lineSeparator.getBytes
+  private val LF_LENGTH = LF_BYTES.length
+}
+
+private[spark] class LineBuffer(lineMaxBytes: Long) extends ByteArrayOutputStream {
+
+  import LineBuffer._
+
+  def tryGenerateContext: Option[String] =
+    if (isLineEnded) {
+      try Some(new String(buf, 0, count - LF_LENGTH)) finally reset()
+    } else if (count >= lineMaxBytes) {
+      try Some(new String(buf, 0, count)) finally reset()
+    } else {
+      None
+    }
+
+  private def isLineEnded: Boolean = {
+    if (count < LF_LENGTH) return false
+    // Fast path on UNIX-like OSes, where the line separator is the single char '\n'
+    if (LF_LENGTH == 1) return LF_BYTES(0) == buf(count - 1)
+
+    var i = 0
+    do {
+      if (LF_BYTES(i) != buf(count - LF_LENGTH + i)) {
+        return false
+      }
+      i = i + 1
+    } while (i < LF_LENGTH)
+    true
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7cb3d068b676..2ff0a8cf3646 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2838,4 +2838,30 @@ package object config {
       .checkValues(Set("connect", "classic"))
       .createWithDefault(
         if (sys.env.get("SPARK_CONNECT_MODE").contains("1")) "connect" else "classic")
+
+  private[spark] val DRIVER_REDIRECT_CONSOLE_OUTPUTS =
+    ConfigBuilder("spark.driver.log.redirectConsoleOutputs")
+      .doc("Comma-separated list of the driver's console outputs to redirect to the logging " +
+        "system. Supported values are `stdout` and `stderr`. It only takes effect when " +
+        s"`${PLUGINS.key}` is configured with `org.apache.spark.deploy.RedirectConsolePlugin`.")
+      .version("4.1.0")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .toSequence
+      .checkValue(v => v.forall(Set("stdout", "stderr").contains),
+        "The value can only be one or more of 'stdout' and 'stderr'.")
+      .createWithDefault(Seq("stdout", "stderr"))
+
+  private[spark] val EXEC_REDIRECT_CONSOLE_OUTPUTS =
+    ConfigBuilder("spark.executor.log.redirectConsoleOutputs")
+      .doc("Comma-separated list of the executors' console outputs to redirect to the logging " +
+        "system. Supported values are `stdout` and `stderr`. It only takes effect when " +
+        s"`${PLUGINS.key}` is configured with `org.apache.spark.deploy.RedirectConsolePlugin`.")
+      .version("4.1.0")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .toSequence
+      .checkValue(v => v.forall(Set("stdout", "stderr").contains),
+        "The value can only be one or more of 'stdout' and 'stderr'.")
+      .createWithDefault(Seq("stdout", "stderr"))
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
index e95eeddbdace..980cd6e541a2 100644
--- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -38,6 +38,8 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
   private val updatePeriodMSec = sc.conf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL)
   // Delay to show up a progress bar, in milliseconds
   private val firstDelayMSec = 500L
+  // Capture stderr (the console for spark-shell) before RedirectConsolePlugin is installed
+  private val console = System.err
   // The width of terminal
   private val TerminalWidth = sys.env.getOrElse("COLUMNS", "80").toInt
 
@@ -92,7 +94,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
     // only refresh if it's changed OR after 1 minute (or the ssh connection will be closed
     // after idle some time)
     if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) {
-      System.err.print(s"$CR$bar$CR")
+      console.print(s"$CR$bar$CR")
       lastUpdateTime = now
     }
     lastProgressBar = bar
@@ -103,7 +105,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
    */
   private def clear(): Unit = {
     if (!lastProgressBar.isEmpty) {
-      System.err.printf(s"$CR${" " * TerminalWidth}$CR")
+      console.printf(s"$CR${" " * TerminalWidth}$CR")
       lastProgressBar = ""
     }
   }
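
For reviewers who want to try this out: a minimal sketch of how the plugin would be enabled. The only keys involved are the two configs introduced in this diff plus `spark.plugins` (the value of `PLUGINS.key`); narrowing the driver to a single stream is purely for illustration, since both configs default to `stdout,stderr`.

```scala
import org.apache.spark.SparkConf

// Enable the built-in plugin; the two new configs then pick which streams it redirects.
val conf = new SparkConf()
  .set("spark.plugins", "org.apache.spark.deploy.RedirectConsolePlugin")
  .set("spark.driver.log.redirectConsoleOutputs", "stderr")          // driver: stderr only
  .set("spark.executor.log.redirectConsoleOutputs", "stdout,stderr") // executors: both (the default)
```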
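On the buffering contract: `LoggingPrintStream` invokes the `redirect` function only once a complete line has accumulated in the underlying `LineBuffer` (or the 4 MiB cap is hit, in which case the partial line is flushed as-is). A sketch of that behavior in test form; since both classes are `private[spark]`, this would only compile inside Spark's own `org.apache.spark.deploy` sources, and it reflects my reading of the patch rather than a test that ships with it.

```scala
package org.apache.spark.deploy

import scala.collection.mutable

object LoggingPrintStreamDemo {
  def main(args: Array[String]): Unit = {
    val records = mutable.Buffer[String]()
    val ps = new LoggingPrintStream(records += _)

    ps.print("partial")  // no line separator yet: nothing is emitted
    ps.print(" line")    // still buffering
    ps.println()         // separator arrives: exactly one record, separator stripped
    assert(records.toSeq == Seq("partial line"))
  }
}
```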
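The `ConsoleProgressBar` hunk is a corollary rather than a separate feature: once the plugin swaps `System.err` for a `LoggingPrintStream`, writing the carriage-return progress bar through `System.err` would flood the logs with partial lines instead of animating in place. Capturing the original stream at construction time keeps the bar on the real terminal. A self-contained sketch of the hazard being avoided, with nothing Spark-specific assumed:

```scala
import java.io.{ByteArrayOutputStream, PrintStream}

val realErr = System.err                 // capture before any redirection happens
val captured = new ByteArrayOutputStream()
System.setErr(new PrintStream(captured)) // stand-in for the plugin's LoggingPrintStream

System.err.print("\r[=====>    ]")       // lands in the replacement stream (i.e. the logs)
realErr.print("\r[=====>    ]")          // still reaches the terminal
assert(captured.toString.startsWith("\r[=====>"))
```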