com.fasterxml.jackson
${spark.shade.packageName}.com.fasterxml.jackson
diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
index 6643a8f361cdc..d430d8c5fb35a 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
@@ -26,7 +26,6 @@ function getThreadDumpEnabled() {
}
function formatStatus(status, type) {
- if (type !== 'display') return status;
if (status) {
return "Active"
} else {
@@ -417,7 +416,6 @@ $(document).ready(function () {
},
{data: 'hostPort'},
{data: 'isActive', render: function (data, type, row) {
- if (type !== 'display') return data;
if (row.isBlacklisted) return "Blacklisted";
else return formatStatus (data, type);
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7f7921d56f49e..e193ed222e228 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -278,4 +278,13 @@ package object config {
"spark.io.compression.codec.")
.booleanConf
.createWithDefault(false)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockThreshold")
+ .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " +
+ "record the size accurately if it's above this config. This helps to prevent OOM by " +
+ "avoiding underestimating shuffle block size when fetch shuffle blocks.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(100 * 1024 * 1024)
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index b2e9a97129f08..048e0d0186594 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,8 +19,13 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.roaringbitmap.RoaringBitmap
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus(
}
/**
- * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger
+ * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks,
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
- * @param avgSize average size of the non-empty blocks
+ * @param avgSize average size of the non-empty and non-huge blocks
+ * @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
- private[this] var avgSize: Long)
+ private[this] var avgSize: Long,
+ @transient private var hugeBlockSizes: Map[Int, Byte])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
- require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
override def getSizeForBlock(reduceId: Int): Long = {
+ assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
0
} else {
- avgSize
+ hugeBlockSizes.get(reduceId) match {
+ case Some(size) => MapStatus.decompressSize(size)
+ case None => avgSize
+ }
}
}
@@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private (
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
+ out.writeInt(hugeBlockSizes.size)
+ hugeBlockSizes.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private (
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
+ val count = in.readInt()
+ val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until count).foreach { _ =>
+ val block = in.readInt()
+ val size = in.readByte()
+ hugeBlockSizesArray += Tuple2(block, size)
+ }
+ hugeBlockSizes = hugeBlockSizesArray.toMap
}
}
@@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus {
// we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
while (i < totalNumBlocks) {
- var size = uncompressedSizes(i)
+ val size = uncompressedSizes(i)
if (size > 0) {
numNonEmptyBlocks += 1
- totalSize += size
+ // Huge blocks are not included in the calculation for average size, thus size for smaller
+ // blocks is more accurate.
+ if (size < threshold) {
+ totalSize += size
+ } else {
+ hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
+ }
} else {
emptyBlocks.add(i)
}
@@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
+ hugeBlockSizesArray.toMap)
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index a0fd29c22ddca..cce7a7611b420 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -631,7 +631,8 @@ private[ui] class JobPagedTable(
{if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"}
|
- {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks,
+ {UIUtils.makeProgressBar(started = job.numActiveTasks,
+ completed = job.completedIndices.size,
failed = job.numFailedTasks, skipped = job.numSkippedTasks,
reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)}
|
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 7370f9feb68cd..1b10feb36e439 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -423,6 +423,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
jobData.numActiveTasks -= 1
taskEnd.reason match {
case Success =>
+ jobData.completedIndices.add((taskEnd.stageId, info.index))
jobData.numCompletedTasks += 1
case kill: TaskKilled =>
jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated(
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 8d280bc00c3b3..048c4ad0146e2 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -62,6 +62,7 @@ private[spark] object UIData {
var numTasks: Int = 0,
var numActiveTasks: Int = 0,
var numCompletedTasks: Int = 0,
+ var completedIndices: OpenHashSet[(Int, Int)] = new OpenHashSet[(Int, Int)](),
var numSkippedTasks: Int = 0,
var numFailedTasks: Int = 0,
var reasonToNumKilled: Map[String, Int] = Map.empty,
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 7a897c2b4698f..c0126e41ff7fa 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -38,6 +38,10 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
override def beforeAll() {
super.beforeAll()
+ // Once 'spark.local.dir' is set, it is cached. Unless the cache is manually cleared
+ // before/after a test, the same directory could be returned even if this property
+ // is reconfigured.
+ Utils.clearLocalRootDirs()
conf.set("spark.shuffle.manager", "sort")
}
@@ -50,6 +54,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
override def afterEach(): Unit = {
try {
Utils.deleteRecursively(tempDir)
+ Utils.clearLocalRootDirs()
} finally {
super.afterEach()
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 759d52fca5ce1..3ec37f674c77b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -17,11 +17,15 @@
package org.apache.spark.scheduler
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
import scala.util.Random
+import org.mockito.Mockito._
import org.roaringbitmap.RoaringBitmap
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.BlockManagerId
@@ -128,4 +132,26 @@ class MapStatusSuite extends SparkFunSuite {
assert(size1 === size2)
assert(!success)
}
+
+ test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " +
+ "underestimated.") {
+ val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000")
+ val env = mock(classOf[SparkEnv])
+ doReturn(conf).when(env).conf
+ SparkEnv.set(env)
+ // Value of element in sizes is equal to the corresponding index.
+ val sizes = (0L to 2000L).toArray
+ val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+ val arrayStream = new ByteArrayOutputStream(102400)
+ val objectOutputStream = new ObjectOutputStream(arrayStream)
+ assert(status1.isInstanceOf[HighlyCompressedMapStatus])
+ objectOutputStream.writeObject(status1)
+ objectOutputStream.flush()
+ val array = arrayStream.toByteArray
+ val objectInput = new ObjectInputStream(new ByteArrayInputStream(array))
+ val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus]
+ (1001 to 2000).foreach {
+ case part => assert(status2.getSizeForBlock(part) >= sizes(part))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index bdd148875e38a..267c8dc1bd750 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -320,12 +320,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
eventually(timeout(5 seconds), interval(50 milliseconds)) {
goToUi(sc, "/jobs")
find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)")
- // Ideally, the following test would pass, but currently we overcount completed tasks
- // if task recomputations occur:
- // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)")
- // Instead, we guarantee that the total number of tasks is always correct, while the number
- // of completed tasks may be higher:
- find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)")
+ find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)")
}
val jobJson = getJson(sc.ui.get, "jobs")
(jobJson \ "numTasks").extract[Int]should be (2)
diff --git a/docs/configuration.md b/docs/configuration.md
index 1d8d963016c71..a6b6d5dfa5f95 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -612,6 +612,15 @@ Apart from these, the following properties are also available, and may be useful
spark.io.compression.codec.
+
+ spark.shuffle.accurateBlockThreshold |
+ 100 * 1024 * 1024 |
+
+ When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the
+ size accurately if it's above this config. This helps to prevent OOM by avoiding
+ underestimating shuffle block size when fetching shuffle blocks.
+ |
+
spark.io.encryption.enabled |
false |
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 362e883e55e83..fb4621389ab92 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -71,21 +71,24 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4
The list below highlights some of the new features and enhancements added to MLlib in the `2.2`
release of Spark:
-* `ALS` methods for _top-k_ recommendations for all users or items, matching the functionality
- in `mllib` ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). Performance
- was also improved for both `ml` and `mllib`
+* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all
+ users or items, matching the functionality in `mllib`
+ ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)).
+ Performance was also improved for both `ml` and `mllib`
([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and
[SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587))
-* `Correlation` and `ChiSquareTest` stats functions for `DataFrames`
+* [`Correlation`](ml-statistics.html#correlation) and
+ [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames`
([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and
[SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635))
-* `FPGrowth` algorithm for frequent pattern mining
+* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining
([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503))
* `GLM` now supports the full `Tweedie` family
([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929))
-* `Imputer` feature transformer to impute missing values in a dataset
+* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset
([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568))
-* `LinearSVC` for linear Support Vector Machine classification
+* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine)
+ for linear Support Vector Machine classification
([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709))
* Logistic regression now supports constraints on the coefficients during training
([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047))
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index 7cf5b7379503f..137ef74843da5 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -68,10 +68,12 @@ public List buildCommand(Map env)
case "org.apache.spark.executor.CoarseGrainedExecutorBackend":
javaOptsKeys.add("SPARK_EXECUTOR_OPTS");
memKey = "SPARK_EXECUTOR_MEMORY";
+ extraClassPath = getenv("SPARK_EXECUTOR_CLASSPATH");
break;
case "org.apache.spark.executor.MesosExecutorBackend":
javaOptsKeys.add("SPARK_EXECUTOR_OPTS");
memKey = "SPARK_EXECUTOR_MEMORY";
+ extraClassPath = getenv("SPARK_EXECUTOR_CLASSPATH");
break;
case "org.apache.spark.deploy.mesos.MesosClusterDispatcher":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
new file mode 100644
index 0000000000000..7f59825504d8e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class DecisionTreeClassifierWrapper private (
+ val pipeline: PipelineModel,
+ val formula: String,
+ val features: Array[String]) extends MLWritable {
+
+ import DecisionTreeClassifierWrapper._
+
+ private val dtcModel: DecisionTreeClassificationModel =
+ pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]
+
+ lazy val numFeatures: Int = dtcModel.numFeatures
+ lazy val featureImportances: Vector = dtcModel.featureImportances
+ lazy val maxDepth: Int = dtcModel.getMaxDepth
+
+ def summary: String = dtcModel.toDebugString
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(dtcModel.getFeaturesCol)
+ .drop(dtcModel.getLabelCol)
+ }
+
+ override def write: MLWriter = new
+ DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this)
+}
+
+private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] {
+
+ val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+ val PREDICTED_LABEL_COL = "prediction"
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ formula: String,
+ maxDepth: Int,
+ maxBins: Int,
+ impurity: String,
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ checkpointInterval: Int,
+ seed: String,
+ maxMemoryInMB: Int,
+ cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ .setForceIndexLabel(true)
+ checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ // get labels and feature names from output schema
+ val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
+
+ // assemble and fit the pipeline
+ val dtc = new DecisionTreeClassifier()
+ .setMaxDepth(maxDepth)
+ .setMaxBins(maxBins)
+ .setImpurity(impurity)
+ .setMinInstancesPerNode(minInstancesPerNode)
+ .setMinInfoGain(minInfoGain)
+ .setCheckpointInterval(checkpointInterval)
+ .setMaxMemoryInMB(maxMemoryInMB)
+ .setCacheNodeIds(cacheNodeIds)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ .setLabelCol(rFormula.getLabelCol)
+ .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+ if (seed != null && seed.length > 0) dtc.setSeed(seed.toLong)
+
+ val idxToStr = new IndexToString()
+ .setInputCol(PREDICTED_LABEL_INDEX_COL)
+ .setOutputCol(PREDICTED_LABEL_COL)
+ .setLabels(labels)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, dtc, idxToStr))
+ .fit(data)
+
+ new DecisionTreeClassifierWrapper(pipeline, formula, features)
+ }
+
+ override def read: MLReader[DecisionTreeClassifierWrapper] =
+ new DecisionTreeClassifierWrapperReader
+
+ override def load(path: String): DecisionTreeClassifierWrapper = super.load(path)
+
+ class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("formula" -> instance.formula) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {
+
+ override def load(path: String): DecisionTreeClassifierWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+ val pipeline = PipelineModel.load(pipelinePath)
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val formula = (rMetadata \ "formula").extract[String]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ new DecisionTreeClassifierWrapper(pipeline, formula, features)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
new file mode 100644
index 0000000000000..de712d67e6df5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class DecisionTreeRegressorWrapper private (
+ val pipeline: PipelineModel,
+ val formula: String,
+ val features: Array[String]) extends MLWritable {
+
+ private val dtrModel: DecisionTreeRegressionModel =
+ pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel]
+
+ lazy val numFeatures: Int = dtrModel.numFeatures
+ lazy val featureImportances: Vector = dtrModel.featureImportances
+ lazy val maxDepth: Int = dtrModel.getMaxDepth
+
+ def summary: String = dtrModel.toDebugString
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(dtrModel.getFeaturesCol)
+ }
+
+ override def write: MLWriter = new
+ DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this)
+}
+
+private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] {
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ formula: String,
+ maxDepth: Int,
+ maxBins: Int,
+ impurity: String,
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ checkpointInterval: Int,
+ seed: String,
+ maxMemoryInMB: Int,
+ cacheNodeIds: Boolean): DecisionTreeRegressorWrapper = {
+
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+
+ // assemble and fit the pipeline
+ val dtr = new DecisionTreeRegressor()
+ .setMaxDepth(maxDepth)
+ .setMaxBins(maxBins)
+ .setImpurity(impurity)
+ .setMinInstancesPerNode(minInstancesPerNode)
+ .setMinInfoGain(minInfoGain)
+ .setCheckpointInterval(checkpointInterval)
+ .setMaxMemoryInMB(maxMemoryInMB)
+ .setCacheNodeIds(cacheNodeIds)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ if (seed != null && seed.length > 0) dtr.setSeed(seed.toLong)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, dtr))
+ .fit(data)
+
+ new DecisionTreeRegressorWrapper(pipeline, formula, features)
+ }
+
+ override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader
+
+ override def load(path: String): DecisionTreeRegressorWrapper = super.load(path)
+
+ class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("formula" -> instance.formula) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] {
+
+ override def load(path: String): DecisionTreeRegressorWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+ val pipeline = PipelineModel.load(pipelinePath)
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val formula = (rMetadata \ "formula").extract[String]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ new DecisionTreeRegressorWrapper(pipeline, formula, features)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index b30ce12bc6cc8..ba6445a730306 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] {
RandomForestRegressorWrapper.load(path)
case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
RandomForestClassifierWrapper.load(path)
+ case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
+ DecisionTreeRegressorWrapper.load(path)
+ case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" =>
+ DecisionTreeClassifierWrapper.load(path)
case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
GBTRegressorWrapper.load(path)
case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 6c39fe5d84865..2b2b5fe49ea32 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -992,7 +992,16 @@ object Matrices {
new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose)
case sm: BSM[Double] =>
// There is no isTranspose flag for sparse matrices in Breeze
- new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
+ val nsm = if (sm.rowIndices.length > sm.activeSize) {
+ // This sparse matrix has trailing zeros.
+ // Remove them by compacting the matrix.
+ val csm = sm.copy
+ csm.compact()
+ csm
+ } else {
+ sm
+ }
+ new SparseMatrix(nsm.rows, nsm.cols, nsm.colPtrs, nsm.rowIndices, nsm.data)
case _ =>
throw new UnsupportedOperationException(
s"Do not support conversion from type ${breeze.getClass.getName}.")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 563756907d201..93c00d80974c3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -513,6 +513,26 @@ class MatricesSuite extends SparkFunSuite {
Matrices.fromBreeze(sum)
}
+ test("Test FromBreeze when Breeze.CSCMatrix.rowIndices has trailing zeros. - SPARK-20687") {
+ // (2, 0, 0)
+ // (2, 0, 0)
+ val mat1Brz = Matrices.sparse(2, 3, Array(0, 2, 2, 2), Array(0, 1), Array(2, 2)).asBreeze
+ // (2, 1E-15, 1E-15)
+ // (2, 1E-15, 1E-15)
+ val mat2Brz = Matrices.sparse(2, 3,
+ Array(0, 2, 4, 6),
+ Array(0, 0, 0, 1, 1, 1),
+ Array(2, 1E-15, 1E-15, 2, 1E-15, 1E-15)).asBreeze
+ val t1Brz = mat1Brz - mat2Brz
+ val t2Brz = mat2Brz - mat1Brz
+ // The following operations raise exceptions on an unpatched Matrices.fromBreeze
+ val t1 = Matrices.fromBreeze(t1Brz)
+ val t2 = Matrices.fromBreeze(t2Brz)
+ // t1 == t1Brz && t2 == t2Brz
+ assert((t1.asBreeze - t1Brz).iterator.map((x) => math.abs(x._2)).sum < 1E-15)
+ assert((t2.asBreeze - t2Brz).iterator.map((x) => math.abs(x._2)).sum < 1E-15)
+ }
+
test("row/col iterator") {
val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0))
val sm = dm.toSparse
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 8d25f5b3a771a..955bc9768ce77 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2082,10 +2082,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
"""
A label indexer that maps a string column of labels to an ML column of label indices.
If the input column is numeric, we cast it to string and index the string values.
- The indices are in [0, numLabels), ordered by label frequencies.
- So the most frequent label gets index 0.
+ The indices are in [0, numLabels). By default, this is ordered by label frequencies
+ so the most frequent label gets index 0. The ordering behavior is controlled by
+ setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'.
- >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error')
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
+ ... stringOrderType="frequencyDesc")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
@@ -2111,26 +2113,45 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
>>> loadedInverter = IndexToString.load(indexToStringPath)
>>> loadedInverter.getLabels() == inverter.getLabels()
True
+ >>> stringIndexer.getStringOrderType()
+ 'frequencyDesc'
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
+ ... stringOrderType="alphabetDesc")
+ >>> model = stringIndexer.fit(stringIndDf)
+ >>> td = model.transform(stringIndDf)
+ >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
+ ... key=lambda x: x[0])
+ [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
.. versionadded:: 1.4.0
"""
+ stringOrderType = Param(Params._dummy(), "stringOrderType",
+ "How to order labels of string column. The first label after " +
+ "ordering is assigned an index of 0. Supported options: " +
+ "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
- def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
+ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
+ stringOrderType="frequencyDesc"):
"""
- __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
+ __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \
+ stringOrderType="frequencyDesc")
"""
super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
- self._setDefault(handleInvalid="error")
+ self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
+ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
+ stringOrderType="frequencyDesc"):
"""
- setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
+ setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \
+ stringOrderType="frequencyDesc")
Sets params for this StringIndexer.
"""
kwargs = self._input_kwargs
@@ -2139,6 +2160,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
def _create_model(self, java_model):
return StringIndexerModel(java_model)
+ @since("2.3.0")
+ def setStringOrderType(self, value):
+ """
+ Sets the value of :py:attr:`stringOrderType`.
+ """
+ return self._set(stringOrderType=value)
+
+ @since("2.3.0")
+ def getStringOrderType(self):
+ """
+ Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
+ """
+ return self.getOrDefault(self.stringOrderType)
+
class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 3c3fcc8d9b8d8..2d17f95b0c44f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -323,6 +323,14 @@ def numInstances(self):
"""
return self._call_java("numInstances")
+ @property
+ @since("2.2.0")
+ def degreesOfFreedom(self):
+ """
+ Degrees of freedom.
+ """
+ return self._call_java("degreesOfFreedom")
+
@property
@since("2.0.0")
def devianceResiduals(self):
@@ -1565,6 +1573,14 @@ def predictionCol(self):
"""
return self._call_java("predictionCol")
+ @property
+ @since("2.2.0")
+ def numInstances(self):
+ """
+ Number of instances in DataFrame predictions.
+ """
+ return self._call_java("numInstances")
+
@property
@since("2.0.0")
def rank(self):
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index 8f5b97ccb1f85..ac7aec7b0a034 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -185,6 +185,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = {
val environment = Environment.newBuilder()
+ val extraClassPath = conf.getOption("spark.executor.extraClassPath")
+ extraClassPath.foreach { cp =>
+ environment.addVariables(
+ Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build())
+ }
val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
// Set the environment variable through a command prefix
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
index 735c879c63c55..66b8e0a640121 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
@@ -106,6 +106,10 @@ private[spark] class MesosFineGrainedSchedulerBackend(
throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
}
val environment = Environment.newBuilder()
+ sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
+ environment.addVariables(
+ Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build())
+ }
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 75bf780d41424..ed423e7e334b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -366,7 +366,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(TableIdentifier(tableName)),
- Map.empty, logicalPlan, overwrite, false)
+ Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false)
def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index a7bf81e98be8e..bf46a39862131 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -232,9 +232,10 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}
override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(DoubleType, DecimalType))
+ Seq(TypeCollection(LongType, DoubleType, DecimalType))
protected override def nullSafeEval(input: Any): Any = child.dataType match {
+ case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
}
@@ -347,9 +348,10 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
}
override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(DoubleType, DecimalType))
+ Seq(TypeCollection(LongType, DoubleType, DecimalType))
protected override def nullSafeEval(input: Any): Any = child.dataType match {
+ case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 7a54995453797..d291ca0020838 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -410,17 +410,20 @@ case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) exten
* would have Map('a' -> Some('1'), 'b' -> None).
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
- * @param ifNotExists If true, only write if the table or partition does not exist.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoTable(
table: LogicalPlan,
partition: Map[String, Option[String]],
query: LogicalPlan,
overwrite: Boolean,
- ifNotExists: Boolean)
+ ifPartitionNotExists: Boolean)
extends LogicalPlan {
- assert(overwrite || !ifNotExists)
- assert(partition.values.forall(_.nonEmpty) || !ifNotExists)
+ // IF NOT EXISTS is only valid in INSERT OVERWRITE
+ assert(overwrite || !ifPartitionNotExists)
+ // IF NOT EXISTS is only valid for static partitions
+ assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists)
// We don't want `table` in children as sometimes we don't want to transform it.
override def children: Seq[LogicalPlan] = query :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b97adf7221d18..c5d69c204642e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -303,7 +303,7 @@ object SQLConf {
val HIVE_MANAGE_FILESOURCE_PARTITIONS =
buildConf("spark.sql.hive.manageFilesourcePartitions")
.doc("When true, enable metastore partition management for file source tables as well. " +
- "This includes both datasource and converted Hive tables. When partition managment " +
+ "This includes both datasource and converted Hive tables. When partition management " +
"is enabled, datasource tables store partition in the Hive metastore, and use the " +
"metastore to prune partitions during query planning.")
.booleanConf
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 31047f688600b..0896caeab8d7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
val plan = testRelation2.select('c).orderBy(Floor('a).asc)
val expected = testRelation2.select(c, a)
- .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c)
+ .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c)
checkAnalysis(plan, expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 1555dd1cf58d4..8ed7a82b943b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -219,6 +219,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType)
}
+ test("cot") {
+ def f: (Double) => Double = (x: Double) => 1 / math.tan(x)
+ testUnary(Cot, f)
+ checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType)
+ }
+
test("atan") {
testUnary(Atan, math.atan)
checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index cca0291b3d5af..d78741d032f38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -176,14 +176,14 @@ class PlanParserSuite extends PlanTest {
def insert(
partition: Map[String, Option[String]],
overwrite: Boolean = false,
- ifNotExists: Boolean = false): LogicalPlan =
- InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+ ifPartitionNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists)
// Single inserts
assertEqual(s"insert overwrite table s $sql",
insert(Map.empty, overwrite = true))
assertEqual(s"insert overwrite table s partition (e = 1) if not exists $sql",
- insert(Map("e" -> Option("1")), overwrite = true, ifNotExists = true))
+ insert(Map("e" -> Option("1")), overwrite = true, ifPartitionNotExists = true))
assertEqual(s"insert into s $sql",
insert(Map.empty))
assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
@@ -193,9 +193,9 @@ class PlanParserSuite extends PlanTest {
val plan2 = table("t").where('x > 5).select(star())
assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
InsertIntoTable(
- table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union(
+ table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union(
InsertIntoTable(
- table("u"), Map.empty, plan2, false, ifNotExists = false)))
+ table("u"), Map.empty, plan2, false, ifPartitionNotExists = false)))
}
test ("insert with if not exists") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1732a8e08b73f..b71c5eb843eec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -286,7 +286,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partition = Map.empty[String, Option[String]],
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
- ifNotExists = false)
+ ifPartitionNotExists = false)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 375df64d39734..17671ea8685b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -111,93 +111,60 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/**
* @since 1.6.1
- * @deprecated use [[newIntSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newLongSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newDoubleSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newFloatSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newByteSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newShortSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newBooleanSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newStringSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newProductSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
- implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+ def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
/** @since 2.2.0 */
- implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
+ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
// Arrays
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 470307bd940ad..bc7e73ae1ba87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.columnar
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -53,219 +53,288 @@ private[columnar] sealed trait ColumnStats extends Serializable {
/**
* Gathers statistics information from `row(ordinal)`.
*/
- def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- if (row.isNullAt(ordinal)) {
- nullCount += 1
- // 4 bytes for null position
- sizeInBytes += 4
- }
+ def gatherStats(row: InternalRow, ordinal: Int): Unit
+
+ /**
+ * Gathers statistics information on `null`.
+ */
+ def gatherNullStats(): Unit = {
+ nullCount += 1
+ // 4 bytes for null position
+ sizeInBytes += 4
count += 1
}
/**
- * Column statistics represented as a single row, currently including closed lower bound, closed
+ * Column statistics represented as an array, currently including closed lower bound, closed
* upper bound and null count.
*/
- def collectedStatistics: GenericInternalRow
+ def collectedStatistics: Array[Any]
}
/**
* A no-op ColumnStats only used for testing purposes.
*/
-private[columnar] class NoopColumnStats extends ColumnStats {
- override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
+private[columnar] final class NoopColumnStats extends ColumnStats {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ count += 1
+ } else {
+ gatherNullStats
+ }
+ }
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
+ override def collectedStatistics: Array[Any] = Array[Any](null, null, nullCount, count, 0L)
}
-private[columnar] class BooleanColumnStats extends ColumnStats {
+private[columnar] final class BooleanColumnStats extends ColumnStats {
protected var upper = false
protected var lower = true
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getBoolean(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += BOOLEAN.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Boolean): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += BOOLEAN.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class ByteColumnStats extends ColumnStats {
+private[columnar] final class ByteColumnStats extends ColumnStats {
protected var upper = Byte.MinValue
protected var lower = Byte.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += BYTE.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Byte): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += BYTE.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class ShortColumnStats extends ColumnStats {
+private[columnar] final class ShortColumnStats extends ColumnStats {
protected var upper = Short.MinValue
protected var lower = Short.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += SHORT.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Short): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += SHORT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class IntColumnStats extends ColumnStats {
+private[columnar] final class IntColumnStats extends ColumnStats {
protected var upper = Int.MinValue
protected var lower = Int.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += INT.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Int): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += INT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class LongColumnStats extends ColumnStats {
+private[columnar] final class LongColumnStats extends ColumnStats {
protected var upper = Long.MinValue
protected var lower = Long.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += LONG.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Long): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += LONG.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class FloatColumnStats extends ColumnStats {
+private[columnar] final class FloatColumnStats extends ColumnStats {
protected var upper = Float.MinValue
protected var lower = Float.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += FLOAT.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Float): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += FLOAT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class DoubleColumnStats extends ColumnStats {
+private[columnar] final class DoubleColumnStats extends ColumnStats {
protected var upper = Double.MinValue
protected var lower = Double.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
- if (value > upper) upper = value
- if (value < lower) lower = value
- sizeInBytes += DOUBLE.defaultSize
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Double): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += DOUBLE.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats extends ColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
- if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
- if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
- sizeInBytes += STRING.actualSize(row, ordinal)
+ val size = STRING.actualSize(row, ordinal)
+ gatherValueStats(value, size)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: UTF8String, size: Int): Unit = {
+ if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
+ if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
+ sizeInBytes += size
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class BinaryColumnStats extends ColumnStats {
+private[columnar] final class BinaryColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- sizeInBytes += BINARY.actualSize(row, ordinal)
+ val size = BINARY.actualSize(row, ordinal)
+ sizeInBytes += size
+ count += 1
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+ override def collectedStatistics: Array[Any] =
+ Array[Any](null, null, nullCount, count, sizeInBytes)
}
-private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
+private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)
protected var upper: Decimal = null
protected var lower: Decimal = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal, precision, scale)
- if (upper == null || value.compareTo(upper) > 0) upper = value
- if (lower == null || value.compareTo(lower) < 0) lower = value
// TODO: this is not right for DecimalType with precision > 18
- sizeInBytes += 8
+ gatherValueStats(value)
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+ def gatherValueStats(value: Decimal): Unit = {
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ sizeInBytes += 8
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
+private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
val columnType = ColumnType(dataType)
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- sizeInBytes += columnType.actualSize(row, ordinal)
+ val size = columnType.actualSize(row, ordinal)
+ sizeInBytes += size
+ count += 1
+ } else {
+ gatherNullStats
}
}
- override def collectedStatistics: GenericInternalRow =
- new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+ override def collectedStatistics: Array[Any] =
+ Array[Any](null, null, nullCount, count, sizeInBytes)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 0a9f3e799990f..3486a6bce8180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -123,8 +123,8 @@ case class InMemoryRelation(
batchStats.add(totalSize)
- val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
- .flatMap(_.values))
+ val stats = InternalRow.fromSeq(
+ columnBuilders.flatMap(_.columnStats.collectedStatistics))
CachedBatch(rowCount, columnBuilders.map { builder =>
JavaUtils.bufferToArray(builder.build())
}, stats)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index bb7d1f70b62d9..14c40605ea31c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -430,6 +430,7 @@ case class DataSource(
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
+ ifPartitionNotExists = false,
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index d307122b5c70d..21d75a404911b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -142,8 +142,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
parts, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)
- case InsertIntoTable(
- l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) =>
+ case i @ InsertIntoTable(
+ l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) =>
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
@@ -195,6 +195,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitions,
+ i.ifPartitionNotExists,
partitionSchema,
t.bucketSpec,
t.fileFormat,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 19b51d4d9530a..c9d31449d3629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -37,10 +37,13 @@ import org.apache.spark.sql.execution.command._
* overwrites: when the spec is empty, all partitions are overwritten.
* When it covers a prefix of the partition keys, only partitions matching
* the prefix are overwritten.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
+ ifPartitionNotExists: Boolean,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
@@ -61,8 +64,8 @@ case class InsertIntoHadoopFsRelationCommand(
val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => "\"" + x + "\""
}.mkString(", ")
- throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
- s"cannot save to file.")
+ throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " +
+ "cannot save to file.")
}
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -76,11 +79,12 @@ case class InsertIntoHadoopFsRelationCommand(
var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+ var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty
// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
- val matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
+ matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
@@ -101,8 +105,12 @@ case class InsertIntoHadoopFsRelationCommand(
case (SaveMode.ErrorIfExists, true) =>
throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
- deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
- true
+ if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
+ false
+ } else {
+ deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
+ true
+ }
case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
case (SaveMode.Ignore, exists) =>
diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
index 1920a108c6584..f7167472b05c6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
@@ -59,3 +59,17 @@ select cot(1);
select cot(null);
select cot(0);
select cot(-1);
+
+-- ceil and ceiling
+select ceiling(0);
+select ceiling(1);
+select ceil(1234567890123456);
+select ceil(12345678901234567);
+select ceiling(1234567890123456);
+select ceiling(12345678901234567);
+
+-- floor
+select floor(0);
+select floor(1);
+select floor(1234567890123456);
+select floor(12345678901234567);
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index abd18211c70d8..fe52005aa91da 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -316,3 +316,83 @@ select cot(-1)
struct
-- !query 37 output
-0.6420926159343306
+
+
+-- !query 38
+select ceiling(0)
+-- !query 38 schema
+struct
+-- !query 38 output
+0
+
+
+-- !query 39
+select ceiling(1)
+-- !query 39 schema
+struct
+-- !query 39 output
+1
+
+
+-- !query 40
+select ceil(1234567890123456)
+-- !query 40 schema
+struct
+-- !query 40 output
+1234567890123456
+
+
+-- !query 41
+select ceil(12345678901234567)
+-- !query 41 schema
+struct
+-- !query 41 output
+12345678901234567
+
+
+-- !query 42
+select ceiling(1234567890123456)
+-- !query 42 schema
+struct
+-- !query 42 output
+1234567890123456
+
+
+-- !query 43
+select ceiling(12345678901234567)
+-- !query 43 schema
+struct
+-- !query 43 output
+12345678901234567
+
+
+-- !query 44
+select floor(0)
+-- !query 44 schema
+struct
+-- !query 44 output
+0
+
+
+-- !query 45
+select floor(1)
+-- !query 45 schema
+struct
+-- !query 45 output
+1
+
+
+-- !query 46
+select floor(1234567890123456)
+-- !query 46 schema
+struct
+-- !query 46 output
+1234567890123456
+
+
+-- !query 47
+select floor(12345678901234567)
+-- !query 47 schema
+struct
+-- !query 47 output
+12345678901234567
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 541565344f758..7e2949ab5aece 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -258,6 +258,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}
+ test("nested sequences") {
+ checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
+ checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
+ }
+
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index b2d04f7c5a6e3..d4e7e362c6c8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -18,33 +18,29 @@
package org.apache.spark.sql.execution.columnar
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
class ColumnStatsSuite extends SparkFunSuite {
- testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
- testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
- testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
- testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
- testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
- testColumnStats(classOf[DoubleColumnStats], DOUBLE,
- createRow(Double.MaxValue, Double.MinValue, 0))
- testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
- testDecimalColumnStats(createRow(null, null, 0))
-
- def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
+ testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
+ testDecimalColumnStats(Array(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: GenericInternalRow): Unit = {
+ initialStatistics: Array[Any]): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+ columnStats.collectedStatistics.zip(initialStatistics).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
- assertResult(10, "Wrong null count")(stats.values(2))
- assertResult(20, "Wrong row count")(stats.values(3))
- assertResult(stats.values(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+ assertResult(10, "Wrong null count")(stats(2))
+ assertResult(20, "Wrong row count")(stats(3))
+ assertResult(stats(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
@@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite {
}
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
- initialStatistics: GenericInternalRow): Unit = {
+ initialStatistics: Array[Any]): Unit = {
val columnStatsName = classOf[DecimalColumnStats].getSimpleName
val columnType = COMPACT_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new DecimalColumnStats(15, 10)
- columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+ columnStats.collectedStatistics.zip(initialStatistics).foreach {
case (actual, expected) => assert(actual === expected)
}
}
@@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
- assertResult(10, "Wrong null count")(stats.values(2))
- assertResult(20, "Wrong row count")(stats.values(3))
- assertResult(stats.values(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+ assertResult(10, "Wrong null count")(stats(2))
+ assertResult(20, "Wrong row count")(stats(3))
+ assertResult(stats(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 09a5eda6e543f..4f090d545cd18 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -160,9 +160,9 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
*/
object HiveAnalysis extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists)
- if DDLUtils.isHiveTable(relation.tableMeta) =>
- InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists)
+ case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists)
+ if DDLUtils.isHiveTable(r.tableMeta) =>
+ InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)
case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
@@ -207,11 +207,11 @@ case class RelationConversions(
override def apply(plan: LogicalPlan): LogicalPlan = {
plan transformUp {
// Write path
- case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists)
+ case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifPartitionNotExists)
// Inserting into partitioned table is not supported in Parquet/Orc data source (yet).
- if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
- !r.isPartitioned && isConvertible(r) =>
- InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists)
+ if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
+ !r.isPartitioned && isConvertible(r) =>
+ InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists)
// Read path
case relation: CatalogRelation
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 41c6b18e9d794..65e8b4e3c725c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -62,7 +62,7 @@ case class CreateHiveTableAsSelectCommand(
Map(),
query,
overwrite = false,
- ifNotExists = false)).toRdd
+ ifPartitionNotExists = false)).toRdd
} else {
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
@@ -78,7 +78,7 @@ case class CreateHiveTableAsSelectCommand(
Map(),
query,
overwrite = true,
- ifNotExists = false)).toRdd
+ ifPartitionNotExists = false)).toRdd
} catch {
case NonFatal(e) =>
// drop the created table.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 10e17c5f73433..10ce8e3730a0d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -71,14 +71,15 @@ import org.apache.spark.SparkException
* }}}.
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
- * @param ifNotExists If true, only write if the table or partition does not exist.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoHiveTable(
table: CatalogTable,
partition: Map[String, Option[String]],
query: LogicalPlan,
overwrite: Boolean,
- ifNotExists: Boolean) extends RunnableCommand {
+ ifPartitionNotExists: Boolean) extends RunnableCommand {
override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
@@ -375,7 +376,7 @@ case class InsertIntoHiveTable(
var doHiveOverwrite = overwrite
- if (oldPart.isEmpty || !ifNotExists) {
+ if (oldPart.isEmpty || !ifPartitionNotExists) {
// SPARK-18107: Insert overwrite runs much slower than hive-client.
// Newer Hive largely improves insert overwrite performance. As Spark uses older Hive
// version and we may not want to catch up new Hive version every time. We delete the
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 7bd3973550043..cc80f2e481cbf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -166,72 +166,54 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("DROP TABLE tmp_table")
}
- test("INSERT OVERWRITE - partition IF NOT EXISTS") {
- withTempDir { tmpDir =>
- val table = "table_with_partition"
- withTable(table) {
- val selQuery = s"select c1, p1, p2 from $table"
- sql(
- s"""
- |CREATE TABLE $table(c1 string)
- |PARTITIONED by (p1 string,p2 string)
- |location '${tmpDir.toURI.toString}'
- """.stripMargin)
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b')
- |SELECT 'blarr'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr", "a", "b"))
-
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b')
- |SELECT 'blarr2'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr2", "a", "b"))
+ testPartitionedTable("INSERT OVERWRITE - partition IF NOT EXISTS") { tableName =>
+ val selQuery = s"select a, b, c, d from $tableName"
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3)
+ |SELECT 1, 4
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(1, 2, 3, 4))
- var e = intercept[AnalysisException] {
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2) IF NOT EXISTS
- |SELECT 'blarr3', 'newPartition'
- """.stripMargin)
- }
- assert(e.getMessage.contains(
- "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]"))
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3)
+ |SELECT 5, 6
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
+
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c) IF NOT EXISTS
+ |SELECT 7, 8, 3
+ """.stripMargin)
+ }
+ assert(e.getMessage.contains(
+ "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]"))
- e = intercept[AnalysisException] {
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2) IF NOT EXISTS
- |SELECT 'blarr3', 'b'
- """.stripMargin)
- }
- assert(e.getMessage.contains(
- "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]"))
+ // If the partition already exists, the insert will overwrite the data
+ // unless users specify IF NOT EXISTS
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3) IF NOT EXISTS
+ |SELECT 9, 10
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
- // If the partition already exists, the insert will overwrite the data
- // unless users specify IF NOT EXISTS
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b') IF NOT EXISTS
- |SELECT 'blarr3'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr2", "a", "b"))
- }
- }
+ // ADD PARTITION has the same effect, even if no actual data is inserted.
+ sql(s"ALTER TABLE $tableName ADD PARTITION (b=21, c=31)")
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=21, c=31) IF NOT EXISTS
+ |SELECT 20, 24
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
}
test("Insert ArrayType.containsNull == false") {