diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 51d7fda0cb260..afc59efaef810 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -24,6 +24,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -210,7 +211,7 @@ private static boolean isSymlink(File file) throws IOException { * The unit is also considered the default if the given string does not specify a unit. */ public static long timeStringAs(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); @@ -258,7 +259,7 @@ public static long timeStringAsSec(String str) { * provided, a direct conversion to the provided unit is attempted. 
*/ public static long byteStringAs(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index c226d8f3bc8fa..a25078e262efb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Locale; import java.util.Properties; import com.google.common.primitives.Ints; @@ -75,7 +76,9 @@ public String getModuleName() { } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + public String ioMode() { + return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT); + } /** If true, we will prefer allocating off-heap byte buffers within Netty. 
*/ public boolean preferDirectBufs() { diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index b38639e854815..dff4f5df68784 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; +import java.util.Locale; import java.util.Set; public enum TaskSorting { @@ -35,7 +36,7 @@ public enum TaskSorting { } public static TaskSorting fromString(String str) { - String lower = str.toLowerCase(); + String lower = str.toLowerCase(Locale.ROOT); for (TaskSorting t: values()) { if (t.alternateNames.contains(lower)) { return t; diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0225fd6056074..99efc4893fda4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -361,7 +361,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def setLogLevel(logLevel: String) { // let's allow lowercase or mixed case too - val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + val upperCased = logLevel.toUpperCase(Locale.ROOT) require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), s"Supplied level $logLevel did not match one of:" + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 539dbb55eeff0..f4a59f069a5f9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket +import java.util.Locale import scala.collection.mutable import scala.util.Properties @@ -319,7 +320,8 @@ object SparkEnv 
extends Logging { "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ba0096d874567..b2b26ee107c00 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -72,7 +73,7 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) } override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index c216fe477fd15..0cb16f0627b72 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io._ +import java.util.Locale 
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.LZ4BlockOutputStream @@ -66,7 +67,8 @@ private[spark] object CompressionCodec { } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val codecClass = + shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb8..fce556fd0382c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} @@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR } val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328ad..88bba2fdbd1c6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis } val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) 
match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 22454e50b14b4..23e31823f4930 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink import java.net.InetSocketAddress -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric } val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } @@ -67,7 +67,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 773e074336cb0..7fa4ba7622980 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -17,7 
+17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} @@ -42,7 +42,7 @@ private[spark] class Slf4jSink( } val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index af9bdefc967ef..aecb3a980e7c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.Locale import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -316,7 +317,7 @@ private[spark] object EventLoggingListener extends Logging { } private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 20cedaf060420..417103436144a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} -import java.util.{NoSuchElementException, Properties} +import java.util.{Locale, NoSuchElementException, Properties} import 
scala.util.control.NonFatal import scala.xml.{Node, XML} @@ -142,7 +142,8 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) defaultValue: SchedulingMode, fileName: String): SchedulingMode = { - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase + val xmlSchedulingMode = + (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT) val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " + s"Fair Scheduler configuration file: $fileName, using " + s"the default schedulingMode: $defaultValue for pool: $poolName" diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 07aea773fa632..c849a16023a7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -56,8 +56,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val maxTaskFailures: Int, private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) - extends TaskScheduler with Logging -{ + extends TaskScheduler with Logging { import TaskSchedulerImpl._ @@ -135,12 +134,13 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) - val schedulingMode: SchedulingMode = try { - SchedulingMode.withName(schedulingModeConf.toUpperCase) - } catch { - case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized 
$SCHEDULER_MODE_PROPERTY: $schedulingModeConf") - } + val schedulingMode: SchedulingMode = + try { + SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT)) + } catch { + case e: java.util.NoSuchElementException => + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") + } val rootPool: Pool = new Pool("", schedulingMode, 0, 0) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 6fc66e2374bd9..e15166d11c243 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import java.util.Locale import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -244,7 +245,8 @@ class KryoDeserializationStream( kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. 
- case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + case e: KryoException + if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") => throw new EOFException } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index dbcc6402bc309..6ce3f511e89c7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -42,7 +43,8 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { - threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + threadTrace1.threadName.toLowerCase(Locale.ROOT) < + threadTrace2.threadName.toLowerCase(Locale.ROOT) } else { v1 > v2 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ff9e5e9411ca..3131c4a1eb7d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.Date +import java.util.{Date, Locale} import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -77,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'content': '
0) { - for (exc <- openStreams.values().asScala) { + for (exc <- openStreams.values) { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", - openStreams.values().asScala.head) + openStreams.values.head) } } } @@ -60,8 +67,7 @@ class DebugFilesystem extends LocalFileSystem { override def open(f: Path, bufferSize: Int): FSDataInputStream = { val wrapped: FSDataInputStream = super.open(f, bufferSize) - openStreams.put(wrapped, new Throwable()) - + addOpenStream(wrapped) new FSDataInputStream(wrapped.getWrappedStream) { override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) @@ -98,7 +104,7 @@ class DebugFilesystem extends LocalFileSystem { override def close(): Unit = { wrapped.close() - openStreams.remove(wrapped) + removeOpenStream(wrapped) } override def read(): Int = wrapped.read() diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index e626ed3621d60..58b865969f517 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers @@ -239,7 +239,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.toLowerCase.contains("serializable")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("serializable")) } test("shuffle with different compression settings (SPARK-3426)") { diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 
82760fe92f76a..46f9ac6b0273a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import java.util.Locale + import scala.util.Random import org.scalatest.Assertions @@ -130,7 +132,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio val thrown = intercept[IllegalStateException] { sc.broadcast(Seq(1, 2, 3)) } - assert(thrown.getMessage.toLowerCase.contains("stopped")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("stopped")) } test("Forbid broadcasting RDD directly") { diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index e2ba0d2a53d04..b72cd8be24206 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.internal.config +import java.util.Locale import java.util.concurrent.TimeUnit import org.apache.spark.{SparkConf, SparkFunSuite} @@ -132,7 +133,7 @@ class ConfigEntrySuite extends SparkFunSuite { val conf = new SparkConf() val transformationConf = ConfigBuilder(testKey("transformation")) .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .createWithDefault("FOO") assert(conf.get(transformationConf) === "foo") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 13020acdd3dbe..c100803279eaf 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import 
java.util.Locale + import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions @@ -374,8 +376,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite storageLevels.foreach { storageLevel => // Put the block into one of the stores - val blockId = new TestBlockId( - "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) + val blockId = TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase(Locale.ROOT)) val testValue = Array.fill[Byte](blockSize)(1) stores(0).putSingle(blockId, testValue, storageLevel) diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 38030e066080f..499d47b13d702 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -37,14 +38,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { test("peak execution memory should displayed") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index f4c561c737794..bdd148875e38a 
100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} +import java.util.Locale import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -453,8 +454,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(10 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") findAll(cssSelector("tbody tr a")).foreach { link => - link.text.toLowerCase should include ("count") - link.text.toLowerCase should not include "unknown" + link.text.toLowerCase(Locale.ROOT) should include ("count") + link.text.toLowerCase(Locale.ROOT) should not include "unknown" } } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index f1be0f6de3ce2..0c3d4caeeabf9 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} import java.net.{URI, URL} +import java.util.Locale import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -72,10 +73,10 @@ class UISuite extends SparkFunSuite { eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("storage")) + assert(html.toLowerCase(Locale.ROOT).contains("environment")) + 
assert(html.toLowerCase(Locale.ROOT).contains("executors")) } } } @@ -85,7 +86,7 @@ class UISuite extends SparkFunSuite { // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString - assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) } } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 1745281c266cc..f736ceed4436f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -203,7 +205,7 @@ object DecisionTreeExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"DecisionTreeExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index db55298d8ea10..ed598d0d7dfae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -140,7 +142,7 @@ object GBTExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"GBTExample with 
parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index a9e07c0705c92..8fd46c37e2987 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -146,7 +148,7 @@ object RandomForestExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"RandomForestExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index b923e627f2095..cd77ecf990b3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import java.util.Locale + import org.apache.log4j.{Level, Logger} import scopt.OptionParser @@ -131,7 +133,7 @@ object LDAExample { // Run LDA. val lda = new LDA() - val optimizer = params.algorithm.toLowerCase match { + val optimizer = params.algorithm.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. 
case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 58b52692b57ce..ab1ce347cbe34 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.UUID +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -74,11 +74,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -115,11 +115,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. 
val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -192,7 +192,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " @@ -207,7 +207,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) @@ -272,7 +272,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister private def validateGeneralOptions(parameters: Map[String, String]): Unit = { // Validate source options - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedStrategies = caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq @@ -451,8 +451,10 @@ 
private[kafka010] object KafkaSourceProvider { offsetOptionKey: String, defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { params.get(offsetOptionKey).map(_.trim) match { - case Some(offset) if offset.toLowerCase == "latest" => LatestOffsetRangeLimit - case Some(offset) if offset.toLowerCase == "earliest" => EarliestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => + LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => + EarliestOffsetRangeLimit case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) case None => defaultOffsets } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 68bc3e3e2e9a8..91893df4ec32f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.common.TopicPartition @@ -195,7 +196,7 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 490535623cb36..4bd052d249eca 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ 
-17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig @@ -75,7 +76,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .option("kafka.bootstrap.servers", testUtils.brokerAddress) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "null topic present in the data")) } @@ -92,7 +93,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Ignore) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode ignore not allowed for kafka")) // Test bad save mode Overwrite @@ -103,7 +104,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Overwrite) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode overwrite not allowed for kafka")) } @@ -233,7 +234,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage - .toLowerCase + .toLowerCase(Locale.ROOT) .contains("topic option required when no 'topic' attribute is present")) try { @@ -248,7 +249,8 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) } test("streaming - write data with valid schema but wrong types") { @@ -270,7 +272,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) try { /* value field wrong type */ @@ -284,7 +286,7 @@ class 
KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "value attribute type must be a string or binarytype")) try { @@ -299,7 +301,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binarytype")) } @@ -318,7 +320,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("job aborted")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -330,7 +332,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) ex = intercept[IllegalArgumentException] { @@ -338,7 +340,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 0046ba7e43d13..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ 
import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -491,7 +491,7 @@ class KafkaSourceSuite extends KafkaSourceTest { reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } @@ -524,7 +524,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option(s"$key", value) reader.load() } - assert(ex.getMessage.toLowerCase.contains("not supported")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("not supported")) } testUnsupportedConfig("kafka.group.id") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 778c06ea16a2b..d2100fc5a4aba 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.kafka010 -import java.{ lang => jl, util => ju } +import java.{lang => jl, util => ju} +import java.util.Locale import scala.collection.JavaConverters._ @@ -93,7 +94,8 @@ private case class Subscribe[K, V]( // but cant seek to a position before poll, because poll is what gets subscription partitions // So, poll, suppress the first exception, then seek val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { @@ -145,7 +147,8 @@ private case class SubscribePattern[K, 
V]( if (!toSeek.isEmpty) { // work around KAFKA-3370 when reset is none, see explanation in Subscribe above val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d5aef8184fc87..78230725f322e 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong, Number => JNumber} import java.nio.charset.StandardCharsets -import java.util.{List => JList, Map => JMap, Set => JSet} +import java.util.{List => JList, Locale, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -206,7 +206,7 @@ object KafkaUtils { kafkaParams: Map[String, String], topics: Set[String] ): Map[TopicAndPartition, Long] = { - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase(Locale.ROOT)) val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7b56bce41c326..965ce3d6f275f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -17,6 
+17,8 @@ package org.apache.spark.ml.classification +import java.util.Locale + import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} @@ -654,7 +656,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { override def load(path: String): LogisticRegression = super.load(path) private[classification] val supportedFamilyNames = - Array("auto", "binomial", "multinomial").map(_.toLowerCase) + Array("auto", "binomial", "multinomial").map(_.toLowerCase(Locale.ROOT)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 55720e2d613d9..2f50dc7c85f35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonAST.JObject @@ -173,7 +175,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. 
Supported: " + supportedOptimizers.mkString(", "), - (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + (o: String) => + ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index c49416b240181..4bd4aa7113f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.r +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.JsonDSL._ @@ -91,7 +93,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) // set variancePower and linkPower if family is tweedie; otherwise, set link function - if (family.toLowerCase == "tweedie") { + if (family.toLowerCase(Locale.ROOT) == "tweedie") { glr.setVariancePower(variancePower).setLinkPower(linkPower) } else { glr.setLink(link) @@ -151,7 +153,7 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = if (family.toLowerCase == "tweedie" && + val rAic: Double = if (family.toLowerCase(Locale.ROOT) == "tweedie" && !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 60dd7367053e2..a20ef72446661 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} import java.io.IOException +import java.util.Locale import scala.collection.mutable import scala.reflect.ClassTag @@ -40,8 +41,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -118,10 +118,11 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo "useful in cross-validation or production scenarios, for handling user/item ids the model " + "has not seen in the training data. Supported values: " + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", - (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + (s: String) => + ALSModel.supportedColdStartStrategies.contains(s.toLowerCase(Locale.ROOT))) /** @group expertGetParam */ - def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase(Locale.ROOT) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3be8b533ee3f3..33137b0c0fdec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.regression +import java.util.Locale + import breeze.stats.{distributions => dist} import org.apache.hadoop.fs.Path @@ 
-57,7 +59,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - (value: String) => supportedFamilyNames.contains(value.toLowerCase)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -99,7 +101,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - (value: String) => supportedLinkNames.contains(value.toLowerCase)) + (value: String) => supportedLinkNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -148,7 +150,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(family).toLowerCase == "tweedie") { + if ($(family).toLowerCase(Locale.ROOT) == "tweedie") { if (isSet(link)) { logWarning("When family is tweedie, use param linkPower to specify link function. 
" + "Setting param link will take no effect.") @@ -460,13 +462,15 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine */ def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = { val familyObj = Family.fromParams(params) - val linkObj = if ((params.getFamily.toLowerCase != "tweedie" && - params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" && - params.isSet(params.linkPower))) { - Link.fromParams(params) - } else { - familyObj.defaultLink - } + val linkObj = + if ((params.getFamily.toLowerCase(Locale.ROOT) != "tweedie" && + params.isSet(params.link)) || + (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie" && + params.isSet(params.linkPower))) { + Link.fromParams(params) + } else { + familyObj.defaultLink + } new FamilyAndLink(familyObj, linkObj) } } @@ -519,7 +523,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family name and variance power */ def fromParams(params: GeneralizedLinearRegressionBase): Family = { - params.getFamily.toLowerCase match { + params.getFamily.toLowerCase(Locale.ROOT) match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -795,7 +799,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family, link and linkPower */ def fromParams(params: GeneralizedLinearRegressionBase): Link = { - if (params.getFamily.toLowerCase == "tweedie") { + if (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie") { params.getLinkPower match { case 0.0 => Log case 1.0 => Identity @@ -804,7 +808,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine case others => new Power(others) } } else { - params.getLink.toLowerCase match { + params.getLink.toLowerCase(Locale.ROOT) match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log @@ 
-1253,8 +1257,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val dispersion: Double = if ( - model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) @@ -1357,8 +1361,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( @Since("2.0.0") lazy val pValues: Array[Double] = { if (isNormalSolver) { - if (model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + if (model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } } else { tValues.map { x => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5eb707dfe7bc3..cd1950bd76c05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import java.util.Locale + import scala.util.Try import org.apache.spark.ml.PredictorParams @@ -218,7 +220,8 @@ private[ml] trait TreeClassifierParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). 
Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "gini") @@ -230,7 +233,7 @@ private[ml] trait TreeClassifierParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -247,7 +250,8 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeClassifierParams @@ -267,7 +271,8 @@ private[ml] trait TreeRegressorParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -279,7 +284,7 @@ private[ml] trait TreeRegressorParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. 
*/ private[ml] def getOldImpurity: OldImpurity = { @@ -295,7 +300,8 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("variance").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams @@ -417,7 +423,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + RandomForestParams.supportedFeatureSubsetStrategies.contains( + value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,13 +438,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait RandomForestClassifierParams @@ -509,7 +516,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { private[ml] object GBTClassifierParams { // The losses below should be lowercase. 
/** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("logistic").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { @@ -523,12 +531,13 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "logistic") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldClassificationLoss = { @@ -544,7 +553,8 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam private[ml] object GBTRegressorParams { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("squared", "absolute").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { @@ -558,12 +568,13 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). 
Supported options:" + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "squared") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6c5f529fb8bfd..4aa647236b31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.Locale + import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.{DeveloperApi, Since} @@ -306,7 +308,7 @@ class LDA private ( @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = - optimizerName.toLowerCase match { + optimizerName.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer case "online" => new OnlineLDAOptimizer case other => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 98a3021461eb8..4c7746869dde1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.impurity +import java.util.Locale + import org.apache.spark.annotation.{DeveloperApi, Since} /** @@ -184,7 +186,7 @@ private[spark] object ImpurityCalculator { * the given stats. 
*/ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity.toLowerCase match { + impurity.toLowerCase(Locale.ROOT) match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 1bec33509580c..ffba99502b148 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -55,7 +55,7 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 30 # seconds duration = .5 @classmethod diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 7f2ec01cc9676..39fc621de7807 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -18,6 +18,7 @@ package org.apache.spark.repl import java.io.File +import java.util.Locale import scala.tools.nsc.GenericRunnerSettings @@ -88,7 +89,7 @@ object Main extends Logging { } val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { if (SparkSession.hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. 
The original default diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 3218d221143e5..424bbca123190 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.util.{Properties, UUID} +import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -532,7 +532,7 @@ private[spark] class Client( try { jarsStream.setLevel(0) jarsDir.listFiles().foreach { f => - if (f.isFile && f.getName.toLowerCase().endsWith(".jar") && f.canRead) { + if (f.isFile && f.getName.toLowerCase(Locale.ROOT).endsWith(".jar") && f.canRead) { jarsStream.putNextEntry(new ZipEntry(f.getName)) Files.copy(f, jarsStream) jarsStream.closeEntry() diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 52b5b347fa9c7..1ecb3d1958f43 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -552,6 +552,8 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last | constant #constantDefault | ASTERISK #star | qualifiedName '.' 
ASTERISK #star @@ -710,7 +712,7 @@ nonReserved | VIEW | REPLACE | IF | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK + | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT @@ -836,6 +838,7 @@ TRANSACTION: 'TRANSACTION'; COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; +IGNORE: 'IGNORE'; IF: 'IF'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 206ae2f0e5eb1..198122759e4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - 
Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val arrayCls = arrayClassFor(elementType) if (elementNullable) { - Invoke(arrayData, "array", arrayCls) + Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => "toIntArray" @@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection { case other => throw new IllegalStateException("expect primitive array element type " + "but got " + other) } - Invoke(arrayData, primitiveMethod, arrayCls) + Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } case t if t <:< localTypeOf[Seq[_]] => @@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection { Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), + returnNullable = false), schemaFor(valueType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) StaticInvoke( ArrayBasedMapData.getClass, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c698ca6a8347c..b0cdef70297cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -617,7 +617,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - lookupTableFromCatalog(u).canonicalized match { + EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.") case other => i.copy(table = other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f8004ca300ac7..c4827b81e8b63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -83,7 +85,7 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => applyBroadcastHint(h.child, h.parameters.toSet) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 254eedfe77517..3ca9e6a8da5b5 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -167,8 +168,10 @@ object CatalogUtils { */ def maskCredentials(options: Map[String, String]): Map[String, String] = { options.map { - case (key, _) if key.toLowerCase == "password" => (key, "###") - case (key, value) if key.toLowerCase == "url" && value.toLowerCase.contains("password") => + case (key, _) if key.toLowerCase(Locale.ROOT) == "password" => (key, "###") + case (key, value) + if key.toLowerCase(Locale.ROOT) == "url" && + value.toLowerCase(Locale.ROOT).contains("password") => (key, "###") case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6f8c6ee2f0f44..faedf5f91c3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -1098,7 +1099,7 @@ class SessionCatalog( name.database.isEmpty && functionRegistry.functionExists(name.funcName) && !FunctionRegistry.builtin.functionExists(name.funcName) && - !hiveFunctions.contains(name.funcName.toLowerCase) + !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } protected def failFunctionLookup(name: String): Nothing = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala index 8e46b962ff432..67bf2d06c95dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.util.Locale + import org.apache.spark.sql.AnalysisException /** A trait that represents the type of a resourced needed by a function. */ @@ -33,7 +35,7 @@ object ArchiveResource extends FunctionResourceType("archive") object FunctionResourceType { def fromString(resourceType: String): FunctionResourceType = { - resourceType.toLowerCase match { + resourceType.toLowerCase(Locale.ROOT) match { case "jar" => JarResource case "file" => FileResource case "archive" => ArchiveResource diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 360e55d922821..cc0cbba275b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -423,8 +423,15 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - /** Only compare table identifier. 
*/ - override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) + override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( + identifier = tableMeta.identifier, + tableType = tableMeta.tableType, + storage = CatalogStorageFormat.empty, + schema = tableMeta.schema, + partitionColumnNames = tableMeta.partitionColumnNames, + bucketSpec = tableMeta.bucketSpec, + createTime = -1 + )) override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6cb..0f8282d3b2f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -89,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case TimestampType => StaticInvoke( @@ -136,16 +136,18 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - 
ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( @@ -262,17 +264,18 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case StringType => - Invoke(input, "toString", ObjectType(classOf[String])) + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), "array", - ObjectType(classOf[Array[_]])) + ObjectType(classOf[Array[_]]), returnNullable = false) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1db26d9c415a7..b847ef7bfaa97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -184,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. 
*/ - def prettyName: String = nodeName.toLowerCase + def prettyName: String = nodeName.toLowerCase(Locale.ROOT) protected def flatArguments: Iterator[Any] = productIterator.flatMap { case t: Traversable[_] => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dea5f85cb08cc..c4d47ab2084fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Locale import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -68,7 +69,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) } // name of function in java.lang.Math - def funcName: String = name.toLowerCase + def funcName: String = name.toLowerCase(Locale.ROOT) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") @@ -124,7 +125,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 53842ef348a57..6d94764f1bfac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,25 +225,26 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. + val assignResult = if (!returnNullable) { + s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + } else { + s""" + if ($funcResult != null) { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ + } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } + $assignResult """ } - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - val code = s""" ${obj.code} boolean ${ev.isNull} = true; @@ -254,7 +255,6 @@ case class Invoke( if (!${ev.isNull}) { $evaluate } - $postNullCheck } """ ev.copy(code = code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1235204591bbd..8acb740f8db8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,6 +90,8 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + // Non-deterministic expressions are not allowed as join conditions. 
+ case e if !e.deterministic => false case l: ListQuery => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b23da537be721..49b779711308f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -60,7 +61,7 @@ abstract class StringRegexExpression extends BinaryExpression } } - override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index b2a3888ff7b08..37190429fc423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -631,7 +633,7 @@ abstract class RankLike extends AggregateWindowFunction { override val updateExpressions 
= increaseRank +: increaseRowNumber +: children override val evaluateExpression: Expression = rank - override def sql: String = s"${prettyName.toUpperCase}()" + override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}()" def withOrder(order: Seq[Expression]): RankLike } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index fdb7d88d5bd7f..ff6c93ae9815c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream +import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -126,7 +127,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || @@ -146,7 +147,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. 
val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index fab7e4c5b1285..e1db1ef5b8695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} +import java.util.Locale import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -31,6 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -1022,6 +1024,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. 
+ */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + /** * Create a (windowed) Function expression. */ @@ -1030,7 +1048,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.namedExpression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) case expressions => @@ -1254,7 +1273,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { val value = string(ctx.STRING) - val valueType = ctx.identifier.getText.toUpperCase + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) try { valueType match { case "DATE" => @@ -1410,7 +1429,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ctx._ val s = value.getText try { - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + val unitText = unit.getText.toLowerCase(Locale.ROOT) + val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { case (u, None) if u.endsWith("s") => // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) @@ -1448,7 +1468,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Resolve/create a primitive type. 
*/ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { case ("boolean", Nil) => BooleanType case ("tinyint" | "byte", Nil) => ByteType case ("smallint" | "short", Nil) => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2d8ec2053a4cb..3008e8cb84659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -359,9 +359,59 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** - * Canonicalized copy of this query plan. + * Returns a plan where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Some nodes should overwrite this to provide proper canonicalize logic. + */ + lazy val canonicalized: PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + preCanonicalized.mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. 
+ Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the exprId too. + id += 1 + ar.withExprId(ExprId(id)) + + case other => normalizeExprId(other) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Do some simple transformation on this plan before canonicalizing. Implementations can override + * this method to provide customized canonicalize logic without rewriting the whole logic. */ - protected lazy val canonicalized: PlanType = this + protected def preCanonicalized: PlanType = this + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { + e.transformUp { + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } /** * Returns true when the given query plan will return the same results as this query plan. @@ -372,49 +422,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * enhancements like caching. However, it is not acceptable to return true if the results could * possibly be different. * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Operators that - * can do better should override this function. 
+ * This function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and/or expression id differences. */ - def sameResult(plan: PlanType): Boolean = { - val left = this.canonicalized - val right = plan.canonicalized - left.getClass == right.getClass && - left.children.size == right.children.size && - left.cleanArgs == right.cleanArgs && - (left.children, right.children).zipped.forall(_ sameResult _) - } + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + */ + final def semanticHash(): Int = canonicalized.hashCode() /** * All the attributes that are used for this plan. */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) - - protected def cleanExpression(e: Expression): Expression = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) - case other => - BindReferences.bindReference(other, allAttributes, allowFailures = true) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - def cleanArg(arg: Any): Any = arg match { - // Children are checked using sameResult above. 
- case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e).canonicalized - case other => other - } - - mapProductIterator { - case s: Option[_] => s.map(cleanArg) - case s: Seq[_] => s.map(cleanArg) - case m: Map[_, _] => m.mapValues(cleanArg) - case other => cleanArg(other) - }.toSeq - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae5..90d11d6d91512 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.plans +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { - def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter case "leftouter" | "left" => LeftOuter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b7177c4a2c4e4..9cd5dfd21b160 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -67,14 +67,6 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false - } - } - override def computeStats(conf: SQLConf): Statistics = 
Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 036b6256684cb..6bdcf490ca5c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) - /** * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function * should only be called on analyzed plans since it will throw [[AnalysisException]] for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c91de08ca5ef6..3ad757ebba851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -803,6 +803,8 @@ case class SubqueryAlias( child: LogicalPlan) extends UnaryNode { + override lazy val canonicalized: LogicalPlan = child.canonicalized + override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9dfdf4da78ff6..2ab46dc8330aa 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any - /** - * Returns true iff this [[BroadcastMode]] generates the same result as `other`. - */ - def compatibleWith(other: BroadcastMode): Boolean + def canonicalized: BroadcastMode } /** @@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows - override def compatibleWith(other: BroadcastMode): Boolean = { - this eq other - } + override def canonicalized: BroadcastMode = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index bdf2baf7361d3..3cd6970ebefbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.sql.streaming.OutputMode /** @@ -47,7 +49,7 @@ private[sql] object InternalOutputModes { def apply(outputMode: String): OutputMode = { - outputMode.toLowerCase match { + outputMode.toLowerCase(Locale.ROOT) match { case "append" => OutputMode.Append case "complete" => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 66dd093bbb691..bb2c5926ae9bb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Builds a map in which keys are case insensitive. Input map can be accessed for cases where * case-sensitive information is required. The primary constructor is marked private to avoid @@ -26,11 +28,12 @@ package org.apache.spark.sql.catalyst.util class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase)) + val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) - override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase) + override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) - override def contains(k: String): Boolean = keyLowerCasedMap.contains(k.toLowerCase) + override def contains(k: String): Boolean = + keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT)) override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = { new CaseInsensitiveMap(originalMap + kv) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 435fba9d8851c..1377a03d93b7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress._ @@ -38,7 +40,7 @@ object CompressionCodecs { * If it is already a class name, 
just return it. */ def getCodecClassName(name: String): String = { - val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase(Locale.ROOT), name) try { // Validate the codec name if (codecName != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f614965520f4a..eb6aad5b2d2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -894,7 +894,7 @@ object DateTimeUtils { * (Because 1970-01-01 is Thursday). */ def getDayOfWeekFromString(string: UTF8String): Int = { - val dowString = string.toString.toUpperCase + val dowString = string.toString.toUpperCase(Locale.ROOT) dowString match { case "SU" | "SUN" | "SUNDAY" => 3 case "MO" | "MON" | "MONDAY" => 4 @@ -951,7 +951,7 @@ object DateTimeUtils { if (format == null) { TRUNC_INVALID } else { - format.toString.toUpperCase match { + format.toString.toUpperCase(Locale.ROOT) match { case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH case _ => TRUNC_INVALID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala index 4565dbde88c88..2beb875d1751d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.spark.internal.Logging sealed trait ParseMode { @@ -45,7 +47,7 @@ object ParseMode extends Logging { /** * Returns the parse mode from the given string. 
*/ - def fromString(mode: String): ParseMode = mode.toUpperCase match { + def fromString(mode: String): ParseMode = mode.toUpperCase(Locale.ROOT) match { case PermissiveMode.name => PermissiveMode case DropMalformedMode.name => DropMalformedMode case FailFastMode.name => FailFastMode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index a7ac6136835a7..812d5ded4bf0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Build a map with String type of key, and it also supports either key case * sensitive or insensitive. @@ -25,7 +27,7 @@ object StringKeyHashMap { def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = if (caseSensitive) { new StringKeyHashMap[T](identity) } else { - new StringKeyHashMap[T](_.toLowerCase) + new StringKeyHashMap[T](_.toLowerCase(Locale.ROOT)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 640c0f189c237..6b0f495033494 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import java.util.{NoSuchElementException, Properties, TimeZone} +import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -243,7 +243,7 @@ object SQLConf { .doc("Sets the compression codec use when writing Parquet files. 
Acceptable values include: " + "uncompressed, snappy, gzip, lzo.") .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") @@ -324,7 +324,7 @@ object SQLConf { "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " + "instead of inferring).") .stringConf - .transform(_.toUpperCase()) + .transform(_.toUpperCase(Locale.ROOT)) .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 26871259c6b6e..520aff5e2b677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -49,7 +51,9 @@ abstract class DataType extends AbstractDataType { /** Name of the type used in JSON serialization. 
*/ def typeName: String = { - this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + this.getClass.getSimpleName + .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT") + .toLowerCase(Locale.ROOT) } private[sql] def jsonValue: JValue = typeName @@ -69,7 +73,7 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString - def sql: String = simpleString.toUpperCase + def sql: String = simpleString.toUpperCase(Locale.ROOT) /** * Check if `this` and `other` are the same data type when ignoring nullability diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 4dc06fc9cf09b..5c4bc5e33c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability @@ -65,7 +67,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" - override def sql: String = typeName.toUpperCase + override def sql: String = typeName.toUpperCase(Locale.ROOT) /** * Returns whether this DecimalType is wider than `other`. 
If yes, it means `other` diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java index e0a54fe30ac7d..d8845e0c838ff 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming; +import java.util.Locale; + import org.junit.Test; public class JavaOutputModeSuite { @@ -24,8 +26,8 @@ public class JavaOutputModeSuite { @Test public void testOutputModes() { OutputMode o1 = OutputMode.Append(); - assert(o1.toString().toLowerCase().contains("append")); + assert(o1.toString().toLowerCase(Locale.ROOT).contains("append")); OutputMode o2 = OutputMode.Complete(); - assert (o2.toString().toLowerCase().contains("complete")); + assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete")); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 1be25ec06c741..82015b1e0671c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -79,7 +81,8 @@ trait AnalysisTest extends PlanTest { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall( + e.getMessage.toLowerCase(Locale.ROOT).contains)) { fail( 
s"""Exception message should contain the following substrings: | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 8f0a0c0d99d15..c39e372c272b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} -import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. 
*/ case class DummyCommand() extends Command @@ -696,7 +697,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBody } expectedMsgs.foreach { m => - if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + if (!e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) { fail(s"Exception message should contain: '$m', " + s"actual exception message:\n\t'${e.getMessage}'") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 7e45028653e36..13bd363c8b692 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.types.{IntegerType, StringType} @@ -32,7 +34,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("better error message for NPE") { val udf = ScalaUDF( - (s: String) => s.toLowerCase, + (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ccd0b7c5d7f79..950aa2379517e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -241,6 +241,16 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: do not push down non-deterministic filters into join condition") { + val x = testRelation.subquery('x) + 
val y = testRelation1.subquery('y) + + val originalQuery = x.join(y).where(Rand(10) > 5.0).analyze + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + test("joins: push to one side after transformCondition") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index d1c6b50536cd2..e7f3b64a71130 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -549,4 +550,11 @@ class ExpressionParserSuite extends PlanTest { val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) } + + test("SPARK-19526 Support ignore nulls keywords for first and last") { + assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) + assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) + assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala index 201dac35ed2d8..3159b541dca79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.OutputMode @@ -40,7 +42,7 @@ class InternalOutputModesSuite extends SparkFunSuite { val acceptedModes = Seq("append", "update", "complete") val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode)) (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } testMode("Xyz") diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 69d797b479159..b203f31a76f03 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.xbean + xbean-asm5-shaded + org.scalacheck scalacheck_${scala.binary.version} @@ -147,11 +151,6 @@ mockito-core test - - org.apache.xbean - xbean-asm5-shaded - test - target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a6..93d565d9fe904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl} +import java.util.Locale import scala.collection.JavaConverters._ @@ -89,7 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 
1.3.1 */ def drop(how: String, cols: Seq[String]): DataFrame = { - how.toLowerCase match { + how.toLowerCase(Locale.ROOT) match { case "any" => drop(cols.size, cols) case "all" => drop(1, cols) case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") @@ -407,10 +408,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + // nanvl only supports these types + nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) case _ => df.col(quotedColName) } - coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) + coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2b8537c3d4a63..49691c15d0f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -164,7 +164,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ @scala.annotation.varargs def load(paths: String*): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 338a6e1314d90..1732a8e08b73f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -66,7 +66,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase(Locale.ROOT) match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore @@ -223,7 +223,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def save(): Unit = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0fe8d87ebd6ba..64755434784a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -108,7 +110,7 @@ class RelationalGroupedDataset protected[sql]( private[this] def strToExpr(expr: String): (Expression => Expression) = { val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase match { + (inputExpr: Expression) => expr.toLowerCase(Locale.ROOT) match { // We special handle a few cases that have alias that are not in function registry. 
case "avg" | "average" | "mean" => UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index c77328690daec..a26d00411fbaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.util.{Map => JMap} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex @@ -47,17 +47,19 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport - && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { - SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() - } else { - if (enableHiveSupport) { - logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + - "falling back to without Hive support.") + val spark = + if (SparkSession.hiveClassesArePresent && enableHiveSupport && + jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == + "hive") { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to " + + "'hive', falling back to without Hive support.") + } + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } - 
SparkSession.builder().sparkContext(jsc.sc).getOrCreate() - } setSparkContextSessionConf(spark, sparkConfigMap) spark } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2fa660c4d5e01..3a9132d74ac11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -119,7 +119,7 @@ case class RowDataSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row @@ -136,19 +136,17 @@ case class RowDataSourceScanExec( """.stripMargin } - // Ignore rdd when checking results - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata - case _ => false - } + // Only care about `relation` and `metadata` when canonicalizing. + override def preCanonicalized: SparkPlan = + copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) } /** * Physical plan node for scanning data from HadoopFsRelations. * * @param relation The file-based relation to scan. - * @param output Output attributes of the scan. - * @param outputSchema Output schema of the scan. + * @param output Output attributes of the scan, including data attributes and partition attributes. + * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. 
@@ -156,7 +154,7 @@ case class RowDataSourceScanExec( case class FileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], - outputSchema: StructType, + requiredSchema: StructType, partitionFilters: Seq[Expression], dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) @@ -267,7 +265,7 @@ case class FileSourceScanExec( val metadata = Map( "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, + "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), "PushedFilters" -> seqToString(pushedDownFilters), @@ -287,7 +285,7 @@ case class FileSourceScanExec( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, + requiredSchema = requiredSchema, filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) @@ -515,14 +513,13 @@ case class FileSourceScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: FileSourceScanExec => - val thisPredicates = partitionFilters.map(cleanExpression) - val otherPredicates = other.partitionFilters.map(cleanExpression) - val result = relation == other.relation && metadata == other.metadata && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: FileSourceScanExec = { + FileSourceScanExec( + relation, + output.map(normalizeExprId(_, output)), + requiredSchema, + partitionFilters.map(normalizeExprId(_, output)), + dataFilters.map(normalizeExprId(_, output)), + None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2827b8ac00331..3d1b481a53e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -87,13 +87,6 @@ case class ExternalRDD[T]( override def newInstance(): ExternalRDD.this.type = ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( @@ -162,13 +155,6 @@ case class LogicalRDD( )(session).asInstanceOf[this.type] } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index e366b9af35c62..19c68c13262a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -33,7 +33,7 @@ case class LocalTableScanExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - private val unsafeRows: Array[InternalRow] = { + private lazy val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 2cdfb7a7828c9..1de4f508b89a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,13 +30,19 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = (super.batches :+ + override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + /** + * Optimization batches that are executed before the regular optimization batches (also before + * the finish analysis batch). + */ + def preOptimizationBatches: Seq[Batch] = Nil + /** * Optimization batches that are executed after the regular optimization batches, but before the * batch executing the [[ExperimentalMethods]] optimizer rules. 
This hook can be used to add diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 80afb59b3e88e..20dacf88504f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Locale + import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -103,7 +105,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") } if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase != "noscan") { + if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) @@ -563,7 +565,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } else if (value.STRING != null) { string(value.STRING) } else if (value.booleanValue != null) { - value.getText.toLowerCase + value.getText.toLowerCase(Locale.ROOT) } else { value.getText } @@ -647,7 +649,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { import ctx._ - val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase(Locale.ROOT)) match { case None | Some("all") => (true, true) case Some("system") => (false, true) case Some("user") => (true, false) @@ -677,7 +679,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { val resources = 
ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase + val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT) resourceType match { case "jar" | "file" | "archive" => FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) @@ -959,7 +961,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .flatMap(_.orderedIdentifier.asScala) .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase != "asc") { + if (dir.toLowerCase(Locale.ROOT) != "asc") { operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -1012,13 +1014,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val mayebePaths = remainder(ctx.identifier).trim ctx.op.getType match { case SqlBaseParser.ADD => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "file" => AddFileCommand(mayebePaths) case "jar" => AddJarCommand(mayebePaths) case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) } case SqlBaseParser.LIST => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "files" | "file" => if (mayebePaths.length > 0) { ListFilesCommand(mayebePaths.split("\\s+")) @@ -1305,7 +1307,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { (rowFormatCtx, createFileFormatCtx.fileFormat) match { case (_, ffTable: TableFileFormatContext) => // OK case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case ("sequencefile" | "textfile" | "rcfile") => // OK case fmt => operationNotAllowed( @@ -1313,7 +1315,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { parentCtx) } case (rfDelimited: RowFormatDelimitedContext, ffGeneric: 
GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case "textfile" => // OK case fmt => operationNotAllowed( s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c31fd92447c0d..c1e1a631c677e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.{broadcast, TaskContext} +import java.util.Locale + +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -43,7 +45,7 @@ trait CodegenSupport extends SparkPlan { case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" case _: DataSourceScanExec => "scan" - case _ => nodeName.toLowerCase + case _ => nodeName.toLowerCase(Locale.ROOT) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 66a8e044ab879..44278e37c5276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -342,8 +342,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows")) - // output attributes should not affect the results - override lazy val cleanArgs: 
Seq[Any] = Seq(start, step, numSlices, numElements) + override lazy val canonicalized: SparkPlan = { + RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) + } override def inputRDDs(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) @@ -607,11 +608,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def sameResult(o: SparkPlan): Boolean = o match { - case s: SubqueryExec => child.sameResult(s.child) - case _ => false - } - @transient private lazy val relationFuture: Future[Array[InternalRow]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9d3c55060dfb6..55540563ef911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool @@ -764,11 +766,11 @@ object DDLUtils { val HIVE_PROVIDER = "hive" def isHiveTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase == HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER } def isDatasourceTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase != HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index ea5398761c46d..5687f9332430e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException} @@ -100,7 +102,7 @@ case class DescribeFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. - functionName.funcName.toLowerCase match { + functionName.funcName.toLowerCase(Locale.ROOT) match { case "<>" => Row(s"Function: $functionName") :: Row("Usage: expr1 <> expr2 - " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index c9384e44255b8..f3b209deaae5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources -import java.util.{ServiceConfigurationError, ServiceLoader} +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -539,15 +538,16 @@ object DataSource { // Found the data source using fully qualified path dataSource case Failure(error) => - if (provider1.toLowerCase == "orc" || + if 
(provider1.toLowerCase(Locale.ROOT) == "orc" || provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( "The ORC data source must be used with Hive support enabled") - } else if (provider1.toLowerCase == "avro" || + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + + "Please find an Avro package at " + + "http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + @@ -596,8 +596,8 @@ object DataSource { */ def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") - val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") + val optionsWithoutPath = options.filterKeys(_.toLowerCase(Locale.ROOT) != "path") CatalogStorageFormat.empty.copy( - locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath) + locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index 5d97558633146..aea27bd4c4d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -94,27 +94,48 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { // Opaque object that uniquely identifies a shared cache user private type ClientId = Object + private val 
warnedAboutEviction = new AtomicBoolean(false) // we use a composite cache key in order to distinguish entries inserted by different clients - private val cache: Cache[(ClientId, Path), Array[FileStatus]] = CacheBuilder.newBuilder() - .weigher(new Weigher[(ClientId, Path), Array[FileStatus]] { + private val cache: Cache[(ClientId, Path), Array[FileStatus]] = { + // [[Weigher]].weigh returns Int so we could only cache objects < 2GB + // instead, the weight is divided by this factor (which is smaller + // than the size of one [[FileStatus]]). + // so it will support objects up to 64GB in size. + val weightScale = 32 + val weigher = new Weigher[(ClientId, Path), Array[FileStatus]] { override def weigh(key: (ClientId, Path), value: Array[FileStatus]): Int = { - (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt - }}) - .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { - override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) - : Unit = { + val estimate = (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)) / weightScale + if (estimate > Int.MaxValue) { + logWarning(s"Cached table partition metadata size is too big. Approximating to " + + s"${Int.MaxValue.toLong * weightScale}.") + Int.MaxValue + } else { + estimate.toInt + } + } + } + val removalListener = new RemovalListener[(ClientId, Path), Array[FileStatus]]() { + override def onRemoval( + removed: RemovalNotification[(ClientId, Path), + Array[FileStatus]]): Unit = { if (removed.getCause == RemovalCause.SIZE && - warnedAboutEviction.compareAndSet(false, true)) { + warnedAboutEviction.compareAndSet(false, true)) { logWarning( "Evicting cached table partition metadata from memory due to size constraints " + - "(spark.sql.hive.filesourcePartitionFileCacheSize = " + maxSizeInBytes + " bytes). " + - "This may impact query planning performance.") + "(spark.sql.hive.filesourcePartitionFileCacheSize = " + + maxSizeInBytes + " bytes). 
This may impact query planning performance.") } - }}) - .maximumWeight(maxSizeInBytes) - .build[(ClientId, Path), Array[FileStatus]]() + } + } + CacheBuilder.newBuilder() + .weigher(weigher) + .removalListener(removalListener) + .maximumWeight(maxSizeInBytes / weightScale) + .build[(ClientId, Path), Array[FileStatus]]() + } + /** * @return a FileStatusCache that does not share any entries with any other client, but does diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 11605dd280569..9897ab73b0da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -245,7 +245,6 @@ object InMemoryFileIndex extends Logging { sessionOpt: Option[SparkSession]): Seq[FileStatus] = { logTrace(s"Listing $path") val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist // Note that statuses only include FileStatus for the files and dirs directly under path, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 4215203960075..3813f953e06a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -43,17 +43,8 @@ case class LogicalRelation( com.google.common.base.Objects.hashCode(relation, output) } - override def sameResult(otherPlan: LogicalPlan): Boolean = { - otherPlan.canonicalized match { - case LogicalRelation(otherRelation, _, _) => relation == otherRelation - case _ => false - } - } - - // When comparing 
two LogicalRelations from within LogicalPlan.sameResult, we only need - // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's - // expId can be different but the relation is still the same. - override lazy val cleanArgs: Seq[Any] = Seq(relation) + // Only care about relation when canonicalizing. + override def preCanonicalized: LogicalPlan = copy(catalogTable = None) @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 03980922ab38f..c3583209efc56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -194,7 +194,7 @@ object PartitioningUtils { while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left // uncleaned. Here we simply ignore them. 
- if (currentPath.getName.toLowerCase == "_temporary") { + if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") { return (None, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 4994b8dc80527..62e4c6e4b4ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -71,9 +71,9 @@ class CSVOptions( val param = parameters.getOrElse(paramName, default.toString) if (param == null) { default - } else if (param.toLowerCase == "true") { + } else if (param.toLowerCase(Locale.ROOT) == "true") { true - } else if (param.toLowerCase == "false") { + } else if (param.toLowerCase(Locale.ROOT) == "false") { false } else { throw new Exception(s"$paramName flag can be true or false") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 110d503f91cf4..f8d4a9bb5b81a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} @@ -75,7 +77,7 @@ case class CreateTempViewUsing( } def run(sparkSession: SparkSession): Seq[Row] = { - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 89fe86c038b16..591096d5efd22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager} -import java.util.Properties +import java.util.{Locale, Properties} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -55,7 +55,7 @@ class JDBCOptions( */ val asConnectionProperties: Properties = { val properties = new Properties() - parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase)) + parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase(Locale.ROOT))) .foreach { case (k, v) => properties.setProperty(k, v) } properties } @@ -141,7 +141,7 @@ object JDBCOptions { private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - jdbcOptionNames += name.toLowerCase + jdbcOptionNames += name.toLowerCase(Locale.ROOT) name } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 774d1ba194321..5fc3c2753b6cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.util.Try @@ -542,7 +543,7 @@ object 
JdbcUtils extends Logging { case ArrayType(et, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition - .toLowerCase.split("\\(")(0) + .toLowerCase(Locale.ROOT).split("\\(")(0) (stmt: PreparedStatement, row: Row, pos: Int) => val array = conn.createArrayOf( typeName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index bdda299a621ac..772d4565de548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -40,9 +42,11 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = parameters.getOrElse("compression", + sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = + shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. 
Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 8b598cc60e778..7abf2ae5166b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ @@ -48,7 +50,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { // will catch it and return the original plan, so that the analyzer can report table not // found later. val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass) - if (!isFileFormat || dataSource.className.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (!isFileFormat || + dataSource.className.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index efcaca9338ad6..9c859e41f8762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,10 +48,8 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override def sameResult(plan: SparkPlan): Boolean = plan match { - case p: BroadcastExchangeExec => - 
mode.compatibleWith(p.mode) && child.sameResult(p.child) - case _ => false + override lazy val canonicalized: SparkPlan = { + BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } @transient diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9a9597d3733e0..d993ea6c6cef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode { case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange) extends LeafExecNode { - override def sameResult(plan: SparkPlan): Boolean = { - // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. - plan.sameResult(child) - } + // Ignore this wrapper for canonicalizing. + override lazy val canonicalized: SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b9f6601ea87fe..2dd1dc3da96c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + HashedRelation(rows.iterator, canonicalized.key, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } - } - - override def compatibleWith(other: BroadcastMode): Boolean = other 
match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey - case _ => false + override lazy val canonicalized: HashedRelationBroadcastMode = { + this.copy(key = key.map(_.canonicalized)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 5f548172f5ced..8857966676ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -304,8 +304,8 @@ class StreamExecution( finishTrigger(dataAvailable) if (dataAvailable) { // Update committed offsets. - committedOffsets ++= availableOffsets batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f9dd80230e488..1426728f9b550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable @@ -599,7 +600,7 @@ private[state] class HDFSBackedStateStoreProvider( val nameParts = path.getName.split("\\.") if (nameParts.size == 2) { val version = nameParts(0).toLong - nameParts(1).toLowerCase match { + 
nameParts(1).toLowerCase(Locale.ROOT) match { case "delta" => // ignore the file otherwise, snapshot file already exists for that batch id if (!versionToFiles.contains(version)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index ca46a1151e3e1..b9515ec7bca2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat case class HiveSerDe( @@ -68,7 +70,7 @@ object HiveSerDe { * @return HiveSerDe associated with the specified source */ def sourceToSerDe(source: String): Option[HiveSerDe] = { - val key = source.toLowerCase match { + val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" case s if s.equals("orcfile") => "orc" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 1ef9d52713d92..0289471bf841a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c3a9cfc08517a..746b2a94f102d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -135,7 +137,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def load(): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f2f700590ca8e..0d2611f9bbcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -230,7 +232,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def start(): StreamingQuery = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 78cf033dd81d7..3ba37addfc8b4 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -119,7 +119,7 @@ public void testCommonOperation() { Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase(Locale.ENGLISH)); + ls.add(it.next().toUpperCase(Locale.ROOT)); } return ls.iterator(); }, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac332..aa237d0619ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -145,6 +145,20 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil ) + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + checkAnswer( Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) .toDF("a", "b").na.fill(2.34), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 82b707537e45f..541565344f758 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(dsBoolean.map(e => !e), false, true) } + test("mapPrimitiveArray") { + val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS() + checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4)) + checkDataset(dsInt.map(e => null: Array[Int]), null, null) + + val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS() + checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D)) + checkDataset(dsDouble.map(e => null: Array[Double]), null, null) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4b69baffab620..d9130fdcfaea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -124,7 +124,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { + if (blackList.exists(t => + testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. 
ignore(testCase.name) { /* Do nothing */ } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 36cde3233dce8..59eaf4d1c29b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { ) } - test("compatible BroadcastMode") { + test("BroadcastMode.canonicalized") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) - assert(mode1.compatibleWith(mode1)) - assert(!mode1.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode1)) - assert(mode2.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode3)) - assert(mode3.compatibleWith(mode3)) + assert(mode1.canonicalized == mode1.canonicalized) + assert(mode1.canonicalized != mode2.canonicalized) + assert(mode2.canonicalized != mode1.canonicalized) + assert(mode2.canonicalized == mode2.canonicalized) + assert(mode2.canonicalized != mode3.canonicalized) + assert(mode3.canonicalized == mode3.canonicalized) } test("BroadcastExchange same result") { @@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) - assert(!exchange3.sameResult(exchange4)) + assert(exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } @@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) - assert(!exchange4.sameResult(exchange5)) + assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 8bceab39f71d5..1c1931b6a6daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext @@ -24,11 +26,12 @@ class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { val badRule = new SparkStrategy { var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + mode.toLowerCase(Locale.ROOT) match { + case "exception" => throw new AnalysisException(mode) + case "error" => throw new Error(mode) + case _ => Nil + } } spark.experimental.extraStrategies = badRule :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 13202a57851e1..97c61dc8694bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI +import java.util.Locale import scala.reflect.{classTag, ClassTag} @@ -40,8 +41,10 @@ class DDLCommandSuite extends PlanTest { val e = intercept[ParseException] { parser.parsePlan(sql) } - 
assert(e.getMessage.toLowerCase.contains("operation not allowed")) - containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) } + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } } private def parseAs[T: ClassTag](query: String): T = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9ebf2dd839a79..fe74ab49f91bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -190,7 +191,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { sql(query) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { @@ -1813,7 +1814,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTable(tabName) { sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") val message = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") }.getMessage assert(message.contains("SHOW COLUMNS with conflicting databases")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 00f5d5db8f5f4..a9511cbd9e4cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -220,6 +221,21 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) } } + + test("SPARK-20280 - FileStatusCache with a partition with very many files") { + /* fake the size, otherwise we need to allocate 2GB of data to trigger this bug */ + class MyFileStatus extends FileStatus with KnownSizeEstimation { + override def estimatedSize: Long = 1000 * 1000 * 1000 + } + /* files * MyFileStatus.estimatedSize should overflow to negative integer + * so, make it between 2bn and 4bn + */ + val files = (1 to 3).map { i => + new MyFileStatus() + } + val fileStatusCache = FileStatusCache.getOrCreate(spark) + fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 57a0af1dda971..94a2f9a00b3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import 
java.util.Locale + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -300,7 +302,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase(Locale.ROOT)) { compressionCodecFor(path, codec.name()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2b20b9716bf80..b4f3de9961209 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -476,7 +476,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha assert(partDf.schema.map(_.name) === Seq("intField", "stringField")) path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file val df = spark.read.parquet(f.getCanonicalPath) assert(df.schema.map(_.name) === Seq("intField", "stringField")) @@ -484,7 +485,8 @@ class ParquetPartitionDiscoverySuite extends 
QueryTest with ParquetTest with Sha } path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file but `basePath` is overridden to // the base path containing partitioning directories val df = spark diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index be56c964a18f8..5a0388ec1d1db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.util.Locale + import scala.language.existentials import org.apache.spark.rdd.RDD @@ -76,7 +78,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase * 5) + Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5) } FiltersPushed.list = filters @@ -113,7 +115,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + val c = (a - 1 + 'a').toChar.toString * 5 + + (a - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5 filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } @@ -151,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5)).toSeq) sqlTest( "SELECT 
a, b FROM oneToTenFiltered", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index f67444fbc49d6..1211242b9fbb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -221,7 +223,7 @@ class FileStreamSinkSuite extends StreamTest { df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) } Seq(mode, "not support").foreach { w => - assert(e.getMessage.toLowerCase.contains(w)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(w)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e5d5b4f328820..f796a4cb4a398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfterAll @@ -105,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte testStream(aggregated, Append)() } Seq("append", "not supported").foreach { m => - assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 05cd3d9f7c2fa..dc2506a48ad00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming.test import java.io.File +import java.util.Locale import java.util.concurrent.TimeUnit import scala.concurrent.duration._ @@ -126,7 +127,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .save() } Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -400,7 +401,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream var e = intercept[IllegalArgumentException](w.foreach(null)) Seq("foreach", "null").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -417,7 +418,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream.partitionBy("value") var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) Seq("foreach", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 7c71e7280c6d3..fb15e7def6dbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.test import java.io.File +import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter @@ -144,7 +145,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .start() } Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -276,13 +277,13 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be var w = df.write.partitionBy("value") var e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "bucketing").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -385,7 +386,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be // Reader, with user specified schema, should just apply user schema on the file data val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() } - assert(e.getMessage.toLowerCase.contains("user specified schema not supported")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains( + "user specified schema not supported")) intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) } diff --git 
a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index 1e6ac4f3df475..c5ade65283045 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import javax.net.ssl.SSLServerSocket; @@ -259,12 +260,12 @@ public static TServerSocket getServerSSLSocket(String hiveHost, int portNum, Str if (thriftServerSocket.getServerSocket() instanceof SSLServerSocket) { List sslVersionBlacklistLocal = new ArrayList(); for (String sslVersion : sslVersionBlacklist) { - sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase()); + sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase(Locale.ROOT)); } SSLServerSocket sslServerSocket = (SSLServerSocket) thriftServerSocket.getServerSocket(); List enabledProtocols = new ArrayList(); for (String protocol : sslServerSocket.getEnabledProtocols()) { - if (sslVersionBlacklistLocal.contains(protocol.toLowerCase())) { + if (sslVersionBlacklistLocal.contains(protocol.toLowerCase(Locale.ROOT))) { LOG.debug("Disabling SSL Protocol: " + protocol); } else { enabledProtocols.add(protocol); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java index ab3ac6285aa02..ad4dfd75f4707 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java @@ -19,6 +19,7 @@ package org.apache.hive.service.auth; import java.util.HashMap; +import java.util.Locale; import java.util.Map; /** @@ -52,7 +53,7 @@ public String toString() { public static 
SaslQOP fromString(String str) { if (str != null) { - str = str.toLowerCase(); + str = str.toLowerCase(Locale.ROOT); } SaslQOP saslQOP = STR_TO_ENUM.get(str); if (saslQOP == null) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java index a96d2ac371cd3..7752ec03a29b7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java @@ -19,6 +19,7 @@ package org.apache.hive.service.cli; import java.sql.DatabaseMetaData; +import java.util.Locale; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hive.service.cli.thrift.TTypeId; @@ -160,7 +161,7 @@ public static Type getType(String name) { if (name.equalsIgnoreCase(type.name)) { return type; } else if (type.isQualifiedType() || type.isComplexType()) { - if (name.toUpperCase().startsWith(type.name)) { + if (name.toUpperCase(Locale.ROOT).startsWith(type.name)) { return type; } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 14553601b1d58..5e4734ad3ad25 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -294,7 +294,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext) private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.toLowerCase(Locale.ENGLISH).equals("http") + transportMode.toLowerCase(Locale.ROOT).equals("http") } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 1bc5c3c62f045..d5cc3b3855045 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -302,7 +302,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() - val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { sessionState.close() System.exit(0) } - if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || cmd_trimmed.startsWith("!") || isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f0e35dff57f7b..806f2be5faeb0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.lang.reflect.InvocationTargetException import java.util +import java.util.Locale import scala.collection.mutable import scala.util.control.NonFatal @@ -499,7 +500,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // We can't use `filterKeys` here, as the map returned by `filterKeys` is not serializable, // while `CatalogTable` should be 
serializable. val propsWithoutPath = table.storage.properties.filter { - case (k, v) => k.toLowerCase != "path" + case (k, v) => k.toLowerCase(Locale.ROOT) != "path" } table.storage.copy(properties = propsWithoutPath ++ newPath.map("path" -> _)) } @@ -1060,7 +1061,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Hive's metastore is case insensitive. However, Hive's createFunction does // not normalize the function name (unlike the getFunction part). So, // we are normalizing the function name. - val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) requireFunctionNotExists(db, functionName) val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9e3eb2dd8234a..c917f110b90f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal @@ -143,7 +145,7 @@ private[sql] class HiveSessionCatalog( // This function is not in functionRegistry, let's try to load it as a Hive's // built-in function. // Hive is case insensitive. 
- val functionName = funcName.unquotedString.toLowerCase + val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { failFunctionLookup(funcName.unquotedString) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0465e9c031e27..09a5eda6e543f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException +import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -184,14 +185,14 @@ case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { private def isConvertible(relation: CatalogRelation): Boolean = { - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET)) || - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)) + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } private def convert(relation: CatalogRelation): LogicalRelation = { - if (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet")) { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + if (serde.contains("parquet")) { val options = Map(ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index afc2bf85334d0..3de60c7fc1318 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -21,6 +21,7 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp +import java.util.Locale import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap @@ -338,7 +339,7 @@ private[spark] object HiveUtils extends Logging { logWarning(s"Hive jar path '$path' does not exist.") Nil } else { - files.filter(_.getName.toLowerCase.endsWith(".jar")) + files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")) } case path => new File(path) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 56ccac32a8d88..387ec4f967233 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -153,7 +154,7 @@ private[hive] class HiveClientImpl( hadoopConf.iterator().asScala.foreach { entry => val key = entry.getKey val value = entry.getValue - if (key.toLowerCase.contains("password")) { + if (key.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") } else { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") @@ -168,7 +169,7 @@ private[hive] class HiveClientImpl( hiveConf.setClassLoader(initClassLoader) // 2: we set all spark confs to this hiveConf. 
sparkConf.getAll.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Spark config to Hive Conf: $k=xxx") } else { logDebug(s"Applying Spark config to Hive Conf: $k=$v") @@ -177,7 +178,7 @@ private[hive] class HiveClientImpl( } // 3: we set all entries in config to this hiveConf. extraConfig.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying extra config to HiveConf: $k=xxx") } else { logDebug(s"Applying extra config to HiveConf: $k=$v") @@ -622,7 +623,7 @@ private[hive] class HiveClientImpl( */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { logDebug(s"Running hiveql '$cmd'") - if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + if (cmd.toLowerCase(Locale.ROOT).startsWith("set")) { logDebug(s"Changing config: $cmd") } try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 2e35f39839488..7abb9f06b1310 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.client import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{InvocationTargetException, Method, Modifier} import java.net.URI -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -505,8 +505,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { 
private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { val resourceUris = f.resources.map { resource => - new ResourceUri( - ResourceType.valueOf(resource.resourceType.resourceType.toUpperCase()), resource.uri) + new ResourceUri(ResourceType.valueOf( + resource.resourceType.resourceType.toUpperCase(Locale.ROOT)), resource.uri) } new HiveFunction( f.identifier.funcName, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 192851028031b..5c515515b9b9c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -29,7 +31,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase) + val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase(Locale.ROOT)) val inputFormat = parameters.get(INPUT_FORMAT) val outputFormat = parameters.get(OUTPUT_FORMAT) @@ -75,7 +77,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) } def serdeProperties: Map[String, String] = parameters.filterKeys { - k => !lowerCasedOptionNames.contains(k.toLowerCase) + k => !lowerCasedOptionNames.contains(k.toLowerCase(Locale.ROOT)) }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } } @@ -83,7 +85,7 @@ object HiveOptions { private val lowerCasedOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - lowerCasedOptionNames += name.toLowerCase + lowerCasedOptionNames += name.toLowerCase(Locale.ROOT) name } @@ -99,5 +101,5 @@ object 
HiveOptions { // The following typo is inherited from Hive... "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", - "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase -> v } + "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 28f074849c0f5..fab0d7fa84827 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -72,7 +72,7 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -80,20 +80,22 @@ case class HiveTableScanExec( BindReferences.bindReference(pred, relation.partitionCols) } - // Create a local copy of hadoopConf,so that scan specific modifications should not impact - // other queries - @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() - - @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) - @transient private val tableDesc = new TableDesc( + @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private lazy val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, hiveQlTable.getOutputFormatClass, hiveQlTable.getMetadata) - // append columns ids and names before broadcast - addColumnMetadataToConf(hadoopConf) + // Create a local copy of hadoopConf,so that scan specific 
modifications should not impact + // other queries + @transient private lazy val hadoopConf = { + val c = sparkSession.sessionState.newHadoopConf() + // append columns ids and names before broadcast + addColumnMetadataToConf(c) + c + } - @transient private val hadoopReader = new HadoopTableReader( + @transient private lazy val hadoopReader = new HadoopTableReader( output, relation.partitionCols, tableDesc, @@ -104,7 +106,7 @@ case class HiveTableScanExec( Cast(Literal(value), dataType).eval(null) } - private def addColumnMetadataToConf(hiveConf: Configuration) { + private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { // Specifies needed column IDs for those non-partitioning columns. val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) @@ -198,18 +200,13 @@ case class HiveTableScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: HiveTableScanExec => - val thisPredicates = partitionPruningPred.map(cleanExpression) - val otherPredicates = other.partitionPruningPred.map(cleanExpression) - - val result = relation.sameResult(other.relation) && - output.length == other.output.length && - output.zip(other.output) - .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: HiveTableScanExec = { + val input: AttributeSeq = relation.output + HiveTableScanExec( + requestedAttributes.map(normalizeExprId(_, input)), + relation.canonicalized.asInstanceOf[CatalogRelation], + partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) } + + override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index ccaa568dcce2a..043eb69818ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.orc +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -41,9 +43,9 @@ private[orc] class OrcOptions(@transient private val parameters: CaseInsensitive val codecName = parameters .get("compression") .orElse(orcCompressionConf) - .getOrElse("snappy").toLowerCase + .getOrElse("snappy").toLowerCase(Locale.ROOT) if (!shortOrcCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 490e02d0bd541..59cc6605a1243 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.net.URI +import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier @@ -49,7 +50,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def analyzeCreateTable(sql: String): CatalogTable = { 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index e48ce2304d086..319d02613f00a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -18,18 +18,15 @@ package org.apache.spark.sql.hive import java.io.File -import java.util.concurrent.{Executors, TimeUnit} import scala.util.Random import org.scalatest.BeforeAndAfterEach -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index e45cf977bfaa2..abe5d835719b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util +import java.util.Locale import scala.util.control.NonFatal @@ -299,10 +300,11 @@ abstract class HiveComparisonTest // thus the tables referenced in those DDL commands cannot be extracted for use by our // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES // command expect these tables to exist. 
- val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + val hasShowTableCommand = + queryList.exists(_.toLowerCase(Locale.ROOT).contains("show tables")) for (table <- Seq("src", "srcpart")) { val hasMatchingQuery = queryList.exists { query => - val normalizedQuery = query.toLowerCase.stripSuffix(";") + val normalizedQuery = query.toLowerCase(Locale.ROOT).stripSuffix(";") normalizedQuery.endsWith(table) || normalizedQuery.contains(s"from $table") || normalizedQuery.contains(s"from default.$table") @@ -444,7 +446,7 @@ abstract class HiveComparisonTest "create table", "drop index" ) - !queryList.map(_.toLowerCase).exists { query => + !queryList.map(_.toLowerCase(Locale.ROOT)).exists { query => excludedSubstrings.exists(s => query.contains(s)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 65a902fc5438e..cf33760360724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -80,7 +80,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd private def assertUnsupportedFeature(body: => Unit): Unit = { val e = intercept[ParseException] { body } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } // Testing the Broadcast based join for cartesian join (cross join) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d012797e19926..75f3744ff35be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,6 
+20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -475,13 +476,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case None => // OK. } // Also make sure that the format and serde are as desired. - assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) - assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.inputFormat.get.toLowerCase(Locale.ROOT).contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase(Locale.ROOT).contains(format)) val serde = catalogTable.storage.serde.get format match { case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) - case _ => assert(serde.toLowerCase.contains(format)) + case _ => assert(serde.toLowerCase(Locale.ROOT).contains(format)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9a760e2947d0b..931f015f03b6f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import java.util.Locale + import scala.reflect.ClassTag import org.apache.spark.SparkContext @@ -60,7 +62,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) .split("(?=[A-Z])") .filter(_.nonEmpty) .mkString(" ") - .toLowerCase + .toLowerCase(Locale.ROOT) .capitalize s"$newName [$id]" } @@ -74,7 +76,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) protected[streaming] override val baseScope: Option[String] = { val scopeName = 
Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } - .getOrElse(name.toLowerCase) + .getOrElse(name.toLowerCase(Locale.ROOT)) Some(new RDDOperationScope(scopeName).toJson) } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 80513de4ee117..90d1f8c5035b3 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -101,7 +101,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { String out = ""; while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out = out + in.next().toUpperCase(Locale.ROOT); } return Arrays.asList(out).iterator(); }); @@ -806,7 +806,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); + JavaPairDStream mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 96f8d9593d630..6c86cacec8279 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -267,7 +267,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out.append(in.next().toUpperCase(Locale.ENGLISH)); + out.append(in.next().toUpperCase(Locale.ROOT)); } return Arrays.asList(out.toString()).iterator(); }); @@ -1315,7 +1315,7 
@@ public void testMapValues() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream mapped = - pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5645996de5a69..eb996c93ff381 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} +import java.util.Locale import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -745,7 +746,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val ex = intercept[IllegalStateException] { body } - assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(expectedErrorMsg)) } }