diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 73e2ffdf007d3..d781a65878b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( Aggregation :: JoinSelection :: InMemoryScans :: + Scripts :: BasicOperators :: Nil) override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index fafb91967086f..2b1d4279a15c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.script.{ScriptTransformationExec, ScriptTransformIOSchema} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery @@ -313,6 +314,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Scripts extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ScriptTransformation(input, script, output, child, ioschema) => + ScriptTransformationExec( + input, + script, + output, + planLater(child), + ScriptTransformIOSchema(ioschema) + ) :: Nil + case _ => Nil + } + } + // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/script/ScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/script/ScriptTransformationExec.scala new file mode 100644 index 0000000000000..0d79f4a59aeb8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/script/ScriptTransformationExec.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.script + +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. + */ +private[sql] +case class ScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformIOSchema) + extends UnaryExecNode with ScriptTransformBase { + + override def producedAttributes: AttributeSet = outputSet -- inputSet + + override def outputPartitioning: Partitioning = child.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = + execute(sqlContext, child, schema) + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration) : Iterator[InternalRow] = { + + val (proc, inputStream, outputStream, stderrBuffer, outputProjection) = + init(input, script, child) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. 
+ val writerThread = new ScriptTransformationWriterThread( + inputIterator, + input.map(_.dataType), + outputProjection, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + val reader = createReader(inputStream) + + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] { + var curLine: String = null + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) + val fieldDelimiter = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD") + + override def hasNext: Boolean = { + try { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread.exception, null, proc, stderrBuffer) + return false + } + } + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread.exception, e, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + val prevLine = curLine + curLine = reader.readLine() + if (!ioschema.isSchemaLess) { + new GenericInternalRow( + prevLine.split(fieldDelimiter).map(CatalystTypeConverters.convertToCatalyst)) + } else { + new GenericInternalRow( + prevLine.split(fieldDelimiter, 2).map(CatalystTypeConverters.convertToCatalyst)) + } + } + } + + writerThread.start() + outputIterator + } +} + +private[sql] trait ScriptTransformBase extends Serializable with Logging { + + def init( + input: Seq[Expression], + script: String, + child: SparkPlan + ): (Process, InputStream, OutputStream, CircularBuffer, InterpretedProjection) = { + + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd.asJava) + + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + (proc, inputStream, outputStream, stderrBuffer, outputProjection) + } + + def execute(sqlContext: SQLContext, + child: SparkPlan, + schema: StructType): RDD[InternalRow] = { + val broadcastedHadoopConf = + new SerializableConfiguration(sqlContext.sessionState.newHadoopConf()) + + child.execute().mapPartitions { iter => + if (iter.hasNext) { + val proj = UnsafeProjection.create(schema) + processIterator(iter, broadcastedHadoopConf.value).map(proj) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } + + def checkFailureAndPropagate( + writerException: Option[Throwable], + cause: Throwable = null, + proc: Process, + stderrBuffer: CircularBuffer): Unit = { + if (writerException.isDefined) { + throw writerException.get + } + + // Checks if the proc is still alive (incase the command ran was bad) + // The ideal way to do this is to use Java 8's Process#isAlive() + // but it cannot be used because Spark still supports Java 7. 
+ // Following is a workaround used to check if a process is alive in Java 7 + // TODO: Once builds are switched to Java 8, this can be changed + try { + val exitCode = proc.exitValue() + if (exitCode != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + throw new SparkException(s"Subprocess exited with status $exitCode. " + + s"Error: ${stderrBuffer.toString}", cause) + } + } catch { + case _: IllegalThreadStateException => + // This means that the process is still alive. Move ahead + } + } + + def createReader(inputStream: InputStream): BufferedReader = + new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + + def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration) : Iterator[InternalRow] +} + +private[sql] class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + outputProjection: Projection, + ioschema: ScriptTransformIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext, + conf: Configuration + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging with Serializable { + + setDaemon(true) + + @volatile protected var _exception: Throwable = null + + protected val lineDelimiter = ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + protected val fieldDelimiter = ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD") + + /** Contains the exception thrown while writing the parent iterator to the external process. */ + def exception: Option[Throwable] = Option(_exception) + + protected def init(): Unit = { + TaskContext.setTaskContext(taskContext) + } + + protected def processRow(row: InternalRow, numColumns: Int): Unit = { + val data = if (numColumns == 0) { + lineDelimiter + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema.head)) + var i = 1 + while (i < numColumns) { + sb.append(fieldDelimiter) + val columnType = inputSchema(i) + val fieldValue = row.get(i, columnType) + val fieldStringValue = columnType match { + case _: DateType => + DateTimeUtils.dateToString(fieldValue.asInstanceOf[SQLDate]) + case _: TimestampType => + DateTimeUtils.timestampToString(fieldValue.asInstanceOf[SQLTimestamp]) + case _ => + fieldValue.toString + } + sb.append(fieldStringValue) + i += 1 + } + sb.append(lineDelimiter) + sb.toString() + } + outputStream.write(data.getBytes(StandardCharsets.UTF_8)) + } + + override def run(): Unit = Utils.logUncaughtExceptions { + init() + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + val numColumns = inputSchema.length + try { + iter.map(outputProjection).foreach(row => processRow(row, numColumns)) + threwException = false + } catch { + case t: Throwable => + // An error occurred while writing input, so kill the child process. 
According to the + // Javadoc this call will not throw an exception: + _exception = t + proc.destroy() + throw t + } finally { + try { + Utils.tryLogNonFatalError(outputStream.close()) + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } + } + } +} + +private[sql] +object ScriptTransformIOSchema { + def apply(input: ScriptInputOutputSchema): ScriptTransformIOSchema = { + new ScriptTransformIOSchema( + input.inputRowFormat, + input.outputRowFormat, + input.schemaLess) + } +} + +/** + * The wrapper class of Hive input and output schema properties + * + * @param inputRowFormat Contains delimiter information for the script's output + * @param outputRowFormat Contains delimiter information for the script's input + * @param schemaLess When set to true, script's output is tokenized as a key-value pair + * else it would be tokenized to extract multiple columns. + */ +private[sql] class ScriptTransformIOSchema ( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + schemaLess: Boolean) extends Serializable { + + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) + + val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + + def isSchemaLess: Boolean = schemaLess +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/script/ScriptTransformationExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/script/ScriptTransformationExecSuite.scala new file mode 100644 index 0000000000000..d53bf235d677b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/script/ScriptTransformationExecSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.script + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class ScriptTransformationExecSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits._ + + private val ioSchema = new ScriptTransformIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + schemaLess = false + ) + + test("cat") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = ioSchema + ), + rowsDf.collect()) + } + + test("script transformation should not swallow errors from upstream operators") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = ioSchema + ), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("SPARK-14400 script transformation should fail for bad script command") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + + val e = intercept[SparkException] { + val plan = + new ScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = ioSchema) + SparkPlanTest.executePlan(plan, sqlContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + } +} + +private [sql] +case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. 
+ throw new IllegalArgumentException("intentional exception") + } + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 9fd03ef8ba037..04efadf6ee0b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -93,7 +93,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) SpecialLimits, InMemoryScans, HiveTableScans, - Scripts, + HiveScripts, Aggregation, JoinSelection, BasicOperators diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9a7111aa3b8b0..ae99f241b0402 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, PreprocessTableInsertion} import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.execution.script.{HiveScriptIOSchema, HiveScriptTransformationExec} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} - /** * Determine the serde/format of the Hive serde table, according to the storage properties. */ @@ -117,11 +117,11 @@ private[hive] trait HiveStrategies { val sparkSession: SparkSession - object Scripts extends Strategy { + object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => val hiveIoSchema = HiveScriptIOSchema(ioschema) - ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil + HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExec.scala new file mode 100644 index 0000000000000..9583538a95f89 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExec.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution.script + +import java.io._ +import java.util.Properties +import javax.annotation.Nullable + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.AbstractSerDe +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.io.Writable + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.script._ +import org.apache.spark.sql.hive.HiveInspectors +import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.{CircularBuffer, Utils} + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. + */ +case class HiveScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: HiveScriptIOSchema) + extends UnaryExecNode with ScriptTransformBase { + + override def producedAttributes: AttributeSet = outputSet -- inputSet + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (proc, inputStream, outputStream, stderrBuffer, outputProjection) = + init(input, script, child) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. 
+ val writerThread = new HiveScriptTransformationWriterThread( + inputIterator, + input.map(_.dataType), + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } + + val reader = createReader(inputStream) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + var curLine: String = null + val scriptOutputStream = new DataInputStream(inputStream) + val fieldDelimiter = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD") + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, hadoopConf).orNull + + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().newInstance + } else { + null + } + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) + + @transient + lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) + + override def hasNext: Boolean = { + try { + if (outputSerde == null) { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread.exception, null, proc, stderrBuffer) + return false + } + } + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + checkFailureAndPropagate(writerThread.exception, null, proc, stderrBuffer) + return false + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + } catch { + case _: EOFException => + // This means that the stdout of `proc` (ie. TRANSFORM process) has exhausted. + // Ideally the proc should *not* be alive at this point but + // there can be a lag between EOF being written out and the process + // being terminated. So explicitly waiting for the process to be done. 
+ proc.waitFor() + checkFailureAndPropagate(writerThread.exception, null, proc, stderrBuffer) + return false + } + } + } + + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread.exception, null, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + if (outputSerde == null) { + val prevLine = curLine + curLine = reader.readLine() + if (!ioschema.schemaLess) { + new GenericInternalRow( + prevLine.split(fieldDelimiter).map(CatalystTypeConverters.convertToCatalyst)) + } else { + new GenericInternalRow( + prevLine.split(fieldDelimiter, 2).map(CatalystTypeConverters.convertToCatalyst)) + } + } else { + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) + } else { + unwrappers(i)(dataList.get(i), mutableRow, i) + } + i += 1 + } + mutableRow + } + } + } + + writerThread.start() + outputIterator + } + + protected override def doExecute(): RDD[InternalRow] = + execute(sqlContext, child, schema) +} + +class HiveScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext, + conf: Configuration + ) extends ScriptTransformationWriterThread( + iter, + inputSchema, + outputProjection, + ioschema, + outputStream, + proc, + stderrBuffer, + taskContext, + conf + ) with Serializable { + + var dataOutputStream: DataOutputStream = null + var scriptInputWriter: Option[RecordWriter] = None + + override def init(): Unit = { + super.init() + dataOutputStream = new DataOutputStream(outputStream) + scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf) + } + + override def processRow(row: InternalRow, numColumns: Int): Unit = { + if (inputSerde == null) { + super.processRow(row, numColumns) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + + if (scriptInputWriter.isDefined) { + scriptInputWriter.get.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } + } + } +} + +object HiveScriptIOSchema { + def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = { + HiveScriptIOSchema( + input.inputRowFormat, + input.outputRowFormat, + input.inputSerdeClass, + input.outputSerdeClass, + input.inputSerdeProps, + input.outputSerdeProps, + input.recordReaderClass, + input.recordWriterClass, + input.schemaLess) + } +} + +/** + * The wrapper class of Hive input and output schema properties + */ +case class HiveScriptIOSchema ( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], + schemaLess: Boolean) + extends ScriptTransformIOSchema( + inputRowFormat, + outputRowFormat, + schemaLess + ) with HiveInspectors with 
Serializable { + + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) + } + } + + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } + } + + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) + (columns, columnTypes) + } + + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { + + val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] + + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + + val properties = new Properties() + properties.putAll(propsMap.asJava) + serde.initialize(null, properties) + + serde + } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExecSuite.scala similarity index 71% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExecSuite.scala index 5318b4650b01f..dbbcc9521db63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/script/HiveScriptTransformationExecSuite.scala @@ -15,21 +15,19 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.hive.execution.script import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.exceptions.TestFailedException -import org.apache.spark.{SparkException, TaskContext, TestUtils} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.script.ExceptionInjectingOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { +class HiveScriptTransformationExecSuite extends SparkPlanTest with TestHiveSingleton { import spark.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( @@ -50,12 +48,10 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { ) test("cat without SerDe") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -66,12 +62,10 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } test("cat with LazySimpleSerDe") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -82,13 +76,11 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } test("script transformation should not swallow errors from upstream operators (no serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -101,13 +93,11 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } test("script transformation should not swallow errors from upstream operators (with serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -120,13 +110,11 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } test("SPARK-14400 script transformation should fail for bad script command") { - 
assume(TestUtils.testCommandAvailable("/bin/bash")) - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[SparkException] { val plan = - new ScriptTransformationExec( + new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), @@ -137,17 +125,3 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { assert(e.getMessage.contains("Subprocess exited with status")) } } - -private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = { - child.execute().map { x => - assert(TaskContext.get() != null) // Make sure that TaskContext is defined. - Thread.sleep(1000) // This sleep gives the external process time to start. - throw new IllegalArgumentException("intentional exception") - } - } - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning -}
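
Usage note (not part of the patch): the no-SerDe wire format that the new sql/core path speaks with the child process is driven entirely by the two row-format delimiters, which ScriptTransformIOSchema defaults to "\t" for fields (TOK_TABLEROWFORMATFIELD) and "\n" for lines (TOK_TABLEROWFORMATLINES). The standalone sketch below only illustrates that contract; the object and method names (ScriptRowFormatSketch, renderRow) are invented for the example, it assumes plain string columns, and it omits the DateType/TimestampType formatting that processRow performs via DateTimeUtils.

// Illustrative only: mirrors how ScriptTransformationWriterThread.processRow renders a row
// for the script's stdin, and how ScriptTransformIOSchema falls back to default delimiters.
object ScriptRowFormatSketch {
  // Same fallback delimiters as ScriptTransformIOSchema's defaultFormat map.
  private val defaultFormat = Map(
    "TOK_TABLEROWFORMATFIELD" -> "\t",
    "TOK_TABLEROWFORMATLINES" -> "\n"
  )

  /** Joins fields with the field delimiter and terminates the row with the line delimiter. */
  def renderRow(fields: Seq[String], rowFormat: Map[String, String] = Map.empty): String = {
    val format = rowFormat.withDefault(defaultFormat)   // explicitly specified delimiters win
    if (fields.isEmpty) {
      format("TOK_TABLEROWFORMATLINES")                 // zero columns: emit just a line break
    } else {
      fields.mkString(format("TOK_TABLEROWFORMATFIELD")) + format("TOK_TABLEROWFORMATLINES")
    }
  }

  def main(args: Array[String]): Unit = {
    // Seq("a", "1") is fed to the script as "a\t1\n"; a script like `cat` echoes it back and
    // the output iterator rebuilds the row by splitting on the same field delimiter.
    println(renderRow(Seq("a", "1")).replace("\t", "\\t").replace("\n", "\\n"))
  }
}

On the read side the roles are symmetric: outputRowFormatMap supplies the delimiter used to split each line coming back from the script, and schemaLess limits that split to at most two tokens, so the script's output is treated as a key/value pair rather than a full column list.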